mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a076adf592 |
29
.github/workflows/publish.yaml
vendored
29
.github/workflows/publish.yaml
vendored
@@ -1,29 +0,0 @@
|
|||||||
# Build the package and publish it to PyPI whenever a tag starting with "v" is pushed.
name: release

on:
  push:
    tags:
      - 'v**'

# Only one publish run per workflow/ref; a newer run cancels the one in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}-publish
  cancel-in-progress: true

jobs:
  build-n-publish:
    # The ubuntu-20.04 hosted image has been retired by GitHub, so jobs
    # requesting it are never scheduled; use a supported image instead.
    runs-on: ubuntu-22.04
    #if: startsWith(github.event.ref, 'refs/tags')
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python 3.10
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Install wheel
        run: pip install wheel && pip install -r requirements.txt
      - name: Build DiffSynth
        run: python setup.py sdist bdist_wheel
      - name: Publish package to PyPI
        # Pass credentials through the environment variables twine reads,
        # instead of interpolating the secret into the shell command line.
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: |
          pip install twine
          twine upload dist/* --skip-existing
|
|
||||||
267
ExVideo_animatediff_train.py
Normal file
267
ExVideo_animatediff_train.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
import torch, json, os, imageio
|
||||||
|
from torchvision.transforms import v2
|
||||||
|
from einops import rearrange
|
||||||
|
import lightning as pl
|
||||||
|
from diffsynth import ModelManager, EnhancedDDIMScheduler, SDVideoPipeline, SDUNet, load_state_dict, SDMotionModel
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def lets_dance(
    unet: SDUNet,
    motion_modules: SDMotionModel,
    sample,
    timestep,
    encoder_hidden_states,
    use_gradient_checkpointing=False,
):
    """Run the UNet with motion modules interleaved and return its output.

    Every entry of ``unet.blocks`` is applied in order. Whenever a block
    index is a key of ``motion_modules.call_block_id``, the motion module
    it maps to is applied right after that block. ControlNet is skipped.

    Args:
        unet: Model exposing ``time_proj``, ``time_embedding``, ``conv_in``,
            ``blocks``, ``conv_norm_out``, ``conv_act`` and ``conv_out``.
        motion_modules: Container exposing ``call_block_id`` (block index ->
            motion module index) and ``motion_modules`` (the modules).
        sample: Input fed to ``unet.conv_in``. The training code in this
            file passes noisy latents arranged as (T, C, H, W).
        timestep: Tensor indexed with ``[None]`` before ``unet.time_proj``.
        encoder_hidden_states: Text embedding handed to every block.
        use_gradient_checkpointing: When true, each block call goes through
            ``torch.utils.checkpoint.checkpoint`` with ``use_reentrant=False``.

    Returns:
        The result of ``unet.conv_out`` on the final hidden states.
    """

    def apply(module, *state):
        # ``state`` is always (hidden_states, time_emb, text_emb, res_stack).
        if use_gradient_checkpointing:
            return torch.utils.checkpoint.checkpoint(
                module, *state, use_reentrant=False,
            )
        return module(*state)

    # Time embedding, cast to the dtype of the input sample.
    time_emb = unet.time_embedding(unet.time_proj(timestep[None]).to(sample.dtype))

    # Pre-processing: the input convolution also seeds the residual stack.
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states
    res_stack = [hidden_states]

    for block_id, block in enumerate(unet.blocks):
        # UNet block.
        hidden_states, time_emb, text_emb, res_stack = apply(
            block, hidden_states, time_emb, text_emb, res_stack
        )
        # AnimateDiff motion module attached to this block, if any.
        if block_id in motion_modules.call_block_id:
            motion_module = motion_modules.motion_modules[motion_modules.call_block_id[block_id]]
            hidden_states, time_emb, text_emb, res_stack = apply(
                motion_module, hidden_states, time_emb, text_emb, res_stack
            )

    # Output head: norm -> activation -> convolution.
    return unet.conv_out(unet.conv_act(unet.conv_norm_out(hidden_states)))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TextVideoDataset(torch.utils.data.Dataset):
    """Dataset of randomly sampled video clips paired with captions.

    The metadata file is a JSON list whose entries have a ``"path"`` (joined
    onto ``base_path``) and a ``"text"`` (a caption, or a list of captions
    from which one is drawn at random). ``__getitem__`` returns one clip per
    configured training shape, retrying with another random video whenever
    one cannot be loaded.

    Each training shape is a 5-tuple
    ``(max_num_frames, interval, num_frames, height, width)``.
    """

    def __init__(self, base_path, metadata_path, steps_per_epoch=10000, training_shapes=((128, 1, 128, 512, 512),)):
        """Read the metadata and build one frame transform per training shape.

        Args:
            base_path: Directory the metadata paths are relative to.
            metadata_path: JSON file listing ``{"path": ..., "text": ...}``.
            steps_per_epoch: Value reported by ``len(dataset)``.
            training_shapes: Iterable of 5-tuples, see the class docstring.
        """
        # Explicit encoding so captions decode the same way on every platform.
        with open(metadata_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        self.path = [os.path.join(base_path, i["path"]) for i in metadata]
        self.text = [i["text"] for i in metadata]
        self.steps_per_epoch = steps_per_epoch
        # Copy so neither the default nor the caller's object is shared.
        self.training_shapes = list(training_shapes)

        self.frame_process = [
            v2.Compose([
                v2.Resize(size=max(height, width), antialias=True),
                v2.CenterCrop(size=(height, width)),
                # (x - 127.5) / 127.5; assumes 8-bit pixel values - TODO confirm.
                v2.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
            ])
            for _max_num_frames, _interval, _num_frames, height, width in self.training_shapes
        ]


    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        """Decode ``num_frames`` frames spaced ``interval`` apart.

        Returns:
            A tensor arranged as (C, T, H, W), or ``None`` when the video has
            fewer than ``max_num_frames`` frames or the requested window
            runs past its last frame.
        """
        # Accept both plain ints and 0-d tensors as the start index.
        start_frame_id = int(start_frame_id)
        reader = imageio.get_reader(file_path)
        try:
            # Counting frames may be expensive, so do it once.
            total_frames = reader.count_frames()
            last_frame_id = start_frame_id + (num_frames - 1) * interval
            if total_frames < max_num_frames or total_frames - 1 < last_frame_id:
                return None

            frames = []
            for frame_id in range(num_frames):
                frame = reader.get_data(start_frame_id + frame_id * interval)
                frame = torch.tensor(frame, dtype=torch.float32)
                frame = rearrange(frame, "H W C -> 1 C H W")
                frames.append(frame_process(frame))
        finally:
            # Release the reader even when decoding or preprocessing fails.
            reader.close()

        frames = torch.concat(frames, dim=0)
        return rearrange(frames, "T C H W -> C T H W")


    def load_video(self, file_path, training_shape_id):
        """Load one randomly positioned clip for the given training shape.

        Returns:
            ``{"frames_<id>": tensor, "start_frame_id_<id>": 0-d tensor}``,
            or ``None`` when the video is unsuitable for this shape.
        """
        max_num_frames, interval, num_frames, height, width = self.training_shapes[training_shape_id]
        frame_process = self.frame_process[training_shape_id]
        start_frame_id = torch.randint(0, max_num_frames - (num_frames - 1) * interval, (1,))[0]
        frames = self.load_frames_using_imageio(file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process)
        if frames is None:
            return None
        return {
            f"frames_{training_shape_id}": frames,
            f"start_frame_id_{training_shape_id}": start_frame_id,
        }


    def __getitem__(self, index):
        """Return a dict with one clip and caption per training shape.

        The video is chosen at random, offset by ``index``. Loading is
        retried with a new random choice until a video loads successfully.
        """
        video_data = {}
        for training_shape_id in range(len(self.training_shapes)):
            while True:
                data_id = torch.randint(0, len(self.path), (1,))[0]
                # For fixed seed. Converted to a plain int for list indexing.
                data_id = int((data_id + index) % len(self.path))
                text = self.text[data_id]
                if isinstance(text, list):
                    text = text[int(torch.randint(0, len(text), (1,))[0])]
                video_file = self.path[data_id]
                try:
                    data = self.load_video(video_file, training_shape_id)
                except Exception:
                    # Unreadable or corrupt video: try another one. Unlike a
                    # bare except, KeyboardInterrupt/SystemExit still propagate.
                    data = None
                if data is not None:
                    data[f"text_{training_shape_id}"] = text
                    break
            video_data.update(data)
        return video_data


    def __len__(self):
        """Return the configured number of steps per epoch."""
        return self.steps_per_epoch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LightningModel(pl.LightningModule):
    """Lightning module that trains only the motion modules of the pipeline.

    The VAE encoder/decoder, text encoder and UNet are put in eval mode
    with gradients disabled; the motion modules are put in train mode
    with gradients enabled and are the only parameters handed to the
    optimizer.
    """

    def __init__(self, learning_rate=1e-5, sd_ckpt_path=None):
        """Load Stable Diffusion, attach fresh motion modules, build the pipeline.

        Args:
            learning_rate: Learning rate for the AdamW optimizer.
            sd_ckpt_path: Checkpoint path passed to ``load_state_dict``.
        """
        super().__init__()

        # Base model, loaded in half precision on the CPU.
        manager = ModelManager(torch_dtype=torch.float16, device="cpu")
        manager.load_stable_diffusion(load_state_dict(sd_ckpt_path))

        # Newly initialized motion modules, placed on this module's dtype/device.
        manager.model["motion_modules"] = SDMotionModel().to(dtype=self.dtype, device=self.device)

        self.pipe = SDVideoPipeline.from_model_manager(manager)

        # Everything except the motion modules is frozen.
        frozen_parts = (
            self.pipe.vae_encoder,
            self.pipe.vae_decoder,
            self.pipe.text_encoder,
            self.pipe.unet,
        )
        for part in frozen_parts:
            part.eval()
            part.requires_grad_(False)

        self.pipe.motion_modules.train()
        self.pipe.motion_modules.requires_grad_(True)

        # Replace the scheduler and configure it with 1000 timesteps.
        self.pipe.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
        self.pipe.scheduler.set_timesteps(1000)

        self.learning_rate = learning_rate


    def encode_video_with_vae(self, video):
        """Encode one video with the VAE and return latents as (T, C, H, W)."""
        batched = video.to(device=self.device, dtype=self.dtype).unsqueeze(0)
        encoded = self.pipe.vae_encoder.encode_video(batched, batch_size=16)
        return rearrange(encoded[0], "C T H W -> T C H W")


    def calculate_loss(self, prompt, frames):
        """Return the MSE between the model output and the added noise.

        Latents, prompt embeddings, the random timestep and the noise are
        prepared without gradients; only the ``lets_dance`` forward pass
        is tracked.
        """
        with torch.no_grad():
            latents = self.encode_video_with_vae(frames)

            # One copy of the prompt embedding per latent frame.
            prompt_embs = self.pipe.prompter.encode_prompt(self.pipe.text_encoder, prompt, device=self.device, max_length=77)
            prompt_embs = prompt_embs.repeat(latents.shape[0], 1, 1)

            # Random timestep, then noise (drawn in this order).
            num_timesteps = len(self.pipe.scheduler.timesteps)
            timestep = torch.randint(0, num_timesteps, (1,), device=self.device)[0]
            noise = torch.randn_like(latents)
            noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)

        prediction = lets_dance(
            self.pipe.unet, self.pipe.motion_modules,
            sample=noisy_latents, encoder_hidden_states=prompt_embs, timestep=timestep
        )
        return torch.nn.functional.mse_loss(prediction.float(), noise.float(), reduction="mean")


    def training_step(self, batch, batch_idx):
        """Compute and log the loss for the first sample of the batch."""
        loss = self.calculate_loss(batch["text_0"][0], batch["frames_0"][0])
        self.log("train_loss", loss, prog_bar=True)
        return loss


    def configure_optimizers(self):
        """Return an AdamW optimizer over the motion module parameters."""
        return torch.optim.AdamW(self.pipe.motion_modules.parameters(), lr=self.learning_rate)


    def on_save_checkpoint(self, checkpoint):
        """Store the names of trainable motion-module parameters in the checkpoint."""
        checkpoint["trainable_param_names"] = [
            name
            for name, param in self.pipe.motion_modules.named_parameters()
            if param.requires_grad
        ]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Data: a single training shape of 16 consecutive frames at 512x512.
    clip_shapes = [(16, 1, 16, 512, 512)]
    dataset = TextVideoDataset(
        "/data/zhongjie/datasets/opensoraplan/data/processed",
        "/data/zhongjie/datasets/opensoraplan/data/processed/metadata.json",
        training_shapes=clip_shapes,
        steps_per_epoch=7*10000,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset, shuffle=True, batch_size=1, num_workers=4
    )

    # Model: Stable Diffusion v1.5 weights with trainable motion modules.
    model = LightningModel(
        learning_rate=1e-5,
        sd_ckpt_path="models/stable_diffusion/v1-5-pruned-emaonly.safetensors",
    )

    # Training: save_top_k=-1 keeps every checkpoint that gets written.
    checkpoint_callback = pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)
    trainer = pl.Trainer(
        max_epochs=100000,
        accelerator="gpu",
        devices="auto",
        strategy="deepspeed_stage_1",
        precision="16-mixed",
        default_root_dir="/data/zhongjie/models/train_extended_animatediff",
        accumulate_grad_batches=1,
        callbacks=[checkpoint_callback]
    )
    trainer.fit(model=model, train_dataloaders=train_loader, ckpt_path=None)
|
||||||
249
README.md
249
README.md
@@ -1,169 +1,92 @@
|
|||||||
# DiffSynth Studio
|
# DiffSynth Studio
|
||||||
[](https://pypi.org/project/DiffSynth/)
|
|
||||||
[](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
|
|
||||||
[](https://github.com/modelscope/DiffSynth-Studio/issues)
|
|
||||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
|
|
||||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
|
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
Document: https://diffsynth-studio.readthedocs.io/zh-cn/latest/index.html
|
|
||||||
|
|
||||||
## Introduction
|
## Introduction
|
||||||
|
|
||||||
DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
|
DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
|
||||||
|
|
||||||
Until now, DiffSynth Studio has supported the following models:
|
## Roadmap
|
||||||
|
|
||||||
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo)
|
|
||||||
* [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
|
|
||||||
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
|
|
||||||
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
|
||||||
* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
|
|
||||||
* [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
|
|
||||||
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
|
|
||||||
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
|
|
||||||
* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
|
|
||||||
* [ESRGAN](https://github.com/xinntao/ESRGAN)
|
|
||||||
* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
|
|
||||||
* [AnimateDiff](https://github.com/guoyww/animatediff/)
|
|
||||||
* [ControlNet](https://github.com/lllyasviel/ControlNet)
|
|
||||||
* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
|
||||||
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
|
||||||
|
|
||||||
## News
|
|
||||||
- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
|
|
||||||
* Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
|
||||||
* Github: [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
|
|
||||||
* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
|
|
||||||
* Training dataset: Coming soon
|
|
||||||
|
|
||||||
- **December 19, 2024** We implement advanced VRAM management for HunyuanVideo, making it possible to generate videos at a resolution of 129x720x1280 using 24GB of VRAM, or at 129x512x384 resolution with just 6GB of VRAM. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
|
|
||||||
|
|
||||||
- **December 18, 2024** We propose ArtAug, an approach designed to improve text-to-image synthesis models through synthesis-understanding interactions. We have trained an ArtAug enhancement module for FLUX.1-dev in the format of LoRA. This model integrates the aesthetic understanding of Qwen2-VL-72B into FLUX.1-dev, leading to an improvement in the quality of generated images.
|
|
||||||
- Paper: https://arxiv.org/abs/2412.12888
|
|
||||||
- Examples: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
|
|
||||||
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
|
|
||||||
- Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (Coming soon)
|
|
||||||
|
|
||||||
- **October 25, 2024** We provide extensive FLUX ControlNet support. This project supports many different ControlNet models that can be freely combined, even if their structures differ. Additionally, ControlNet models are compatible with high-resolution refinement and partition control techniques, enabling very powerful controllable image generation. See [`./examples/ControlNet/`](./examples/ControlNet/).
|
|
||||||
|
|
||||||
- **October 8, 2024.** We release the extended LoRA based on CogVideoX-5B and ExVideo. You can download this model from [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) or [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1).
|
|
||||||
|
|
||||||
- **August 22, 2024.** CogVideoX-5B is supported in this project. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including
|
|
||||||
- Text to video
|
|
||||||
- Video editing
|
|
||||||
- Self-upscaling
|
|
||||||
- Video interpolation
|
|
||||||
|
|
||||||
- **August 22, 2024.** We have implemented an interesting painter that supports all text-to-image models. Now you can create stunning images using the painter, with assistance from AI!
|
|
||||||
- Use it in our [WebUI](#usage-in-webui).
|
|
||||||
|
|
||||||
- **August 21, 2024.** FLUX is supported in DiffSynth-Studio.
|
|
||||||
- Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
|
|
||||||
- LoRA, ControlNet, and additional models will be available soon.
|
|
||||||
|
|
||||||
- **June 21, 2024.** 🔥🔥🔥 We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
|
||||||
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
|
||||||
- Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
|
|
||||||
- Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
|
|
||||||
- Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
|
|
||||||
- You can try ExVideo in this [Demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!
|
|
||||||
|
|
||||||
- **June 13, 2024.** DiffSynth Studio is transferred to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
|
|
||||||
|
|
||||||
- **Jan 29, 2024.** We propose Diffutoon, a fantastic solution for toon shading.
|
|
||||||
- [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
|
|
||||||
- The source codes are released in this project.
|
|
||||||
- The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
|
|
||||||
|
|
||||||
- **Dec 8, 2023.** We decided to develop a new project, aiming to unleash the potential of diffusion models, especially in video synthesis. Development of this project has started.
|
|
||||||
|
|
||||||
- **Nov 15, 2023.** We propose FastBlend, a powerful video deflickering algorithm.
|
|
||||||
- The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
|
|
||||||
- Demo videos are shown on Bilibili, including three tasks.
|
|
||||||
- [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
|
|
||||||
- [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
|
|
||||||
- [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
|
|
||||||
- The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
|
|
||||||
- An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
|
|
||||||
|
|
||||||
- **Oct 1, 2023.** We release an early version of this project, namely FastSDXL: an attempt at building a diffusion engine.
|
|
||||||
- The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
|
|
||||||
- FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
|
|
||||||
- The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
|
|
||||||
- The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
|
|
||||||
- A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
|
|
||||||
- Since OLSS requires additional training, we don't implement it in this project.
|
|
||||||
|
|
||||||
- **Aug 29, 2023.** We propose DiffSynth, a video synthesis framework.
|
|
||||||
- [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
|
|
||||||
- The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
|
|
||||||
- The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
|
|
||||||
|
|
||||||
|
* Aug 29, 2023. We propose DiffSynth, a video synthesis framework.
|
||||||
|
* [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
|
||||||
|
* The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
|
||||||
|
* The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
|
||||||
|
* Oct 1, 2023. We release an early version of this project, namely FastSDXL: an attempt at building a diffusion engine.
|
||||||
|
* The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
|
||||||
|
* FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
|
||||||
|
* The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
|
||||||
|
* The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
|
||||||
|
* A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
|
||||||
|
* Since OLSS requires additional training, we don't implement it in this project.
|
||||||
|
* Nov 15, 2023. We propose FastBlend, a powerful video deflickering algorithm.
|
||||||
|
* The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
|
||||||
|
* Demo videos are shown on Bilibili, including three tasks.
|
||||||
|
* [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
|
||||||
|
* [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
|
||||||
|
* [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
|
||||||
|
* The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
|
||||||
|
* An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
|
||||||
|
* Dec 8, 2023. We decided to develop a new project, aiming to unleash the potential of diffusion models, especially in video synthesis. Development of this project has started.
|
||||||
|
* Jan 29, 2024. We propose Diffutoon, a fantastic solution for toon shading.
|
||||||
|
* [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/).
|
||||||
|
* The source codes are released in this project.
|
||||||
|
* The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
|
||||||
|
* June 13, 2024. DiffSynth Studio has been transferred to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
|
||||||
|
* June 21, 2024. We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
||||||
|
* [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/).
|
||||||
|
* Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
|
||||||
|
* Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
|
||||||
|
* Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
|
||||||
|
* Until now, DiffSynth Studio has supported the following models:
|
||||||
|
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||||
|
* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||||
|
* [ControlNet](https://github.com/lllyasviel/ControlNet)
|
||||||
|
* [AnimateDiff](https://github.com/guoyww/animatediff/)
|
||||||
|
* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
|
||||||
|
* [ESRGAN](https://github.com/xinntao/ESRGAN)
|
||||||
|
* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
|
||||||
|
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
|
||||||
|
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
|
||||||
|
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
Install from source code (recommended):
|
Create Python environment:
|
||||||
|
|
||||||
```
|
```
|
||||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
conda env create -f environment.yml
|
||||||
cd DiffSynth-Studio
|
|
||||||
pip install -e .
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Or install from pypi:
|
Sometimes `conda` cannot install `cupy` correctly; if that happens, please install it manually. See [this document](https://docs.cupy.dev/en/stable/install.html) for more details.
|
||||||
|
|
||||||
|
Enter the Python environment:
|
||||||
|
|
||||||
```
|
```
|
||||||
pip install diffsynth
|
conda activate DiffSynthStudio
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage (in Python code)
|
## Usage (in Python code)
|
||||||
|
|
||||||
The Python examples are in [`examples`](./examples/). We provide an overview here.
|
The Python examples are in [`examples`](./examples/). We provide an overview here.
|
||||||
|
|
||||||
### Download Models
|
### Long Video Synthesis
|
||||||
|
|
||||||
Download the pre-set models. Model IDs can be found in [config file](/diffsynth/configs/model_config.py).
|
We trained an extended video synthesis model, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)
|
||||||
|
|
||||||
```python
|
|
||||||
from diffsynth import download_models
|
|
||||||
|
|
||||||
download_models(["FLUX.1-dev", "Kolors"])
|
|
||||||
```
|
|
||||||
|
|
||||||
Download your own models.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from diffsynth.models.downloader import download_from_huggingface, download_from_modelscope
|
|
||||||
|
|
||||||
# From Modelscope (recommended)
|
|
||||||
download_from_modelscope("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp16.bin", "models/kolors/Kolors/vae")
|
|
||||||
# From Huggingface
|
|
||||||
download_from_huggingface("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp16.safetensors", "models/kolors/Kolors/vae")
|
|
||||||
```
|
|
||||||
|
|
||||||
### Video Synthesis
|
|
||||||
|
|
||||||
#### Text-to-video using CogVideoX-5B
|
|
||||||
|
|
||||||
CogVideoX-5B is released by ZhiPu. We provide an improved pipeline, supporting text-to-video, video editing, self-upscaling and video interpolation. [`examples/video_synthesis`](./examples/video_synthesis/)
|
|
||||||
|
|
||||||
The video on the left is generated using the original text-to-video pipeline, while the video on the right is the result after editing and frame interpolation.
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
|
|
||||||
|
|
||||||
#### Long Video Synthesis
|
|
||||||
|
|
||||||
We trained extended video synthesis models, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)
|
|
||||||
|
|
||||||
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
|
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/321ee04b-8c17-479e-8a95-8cbcf21f8d7e
|
### Image Synthesis
|
||||||
|
|
||||||
#### Toon Shading
|
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/)
|
||||||
|
|
||||||
|
|512*512|1024*1024|2048*2048|4096*4096|
|
||||||
|
|-|-|-|-|
|
||||||
|
|||||
|
||||||
|
|
||||||
|
|1024*1024|2048*2048|
|
||||||
|
|-|-|
|
||||||
|
|||
|
||||||
|
|
||||||
|
### Toon Shading
|
||||||
|
|
||||||
Render realistic videos in a flat style and enable video editing features. [`examples/Diffutoon`](./examples/Diffutoon/)
|
Render realistic videos in a flat style and enable video editing features. [`examples/Diffutoon`](./examples/Diffutoon/)
|
||||||
|
|
||||||
@@ -171,60 +94,32 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
|
|||||||
|
|
||||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c
|
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c
|
||||||
|
|
||||||
#### Video Stylization
|
### Video Stylization
|
||||||
|
|
||||||
Video stylization without video models. [`examples/diffsynth`](./examples/diffsynth/)
|
Video stylization without video models. [`examples/diffsynth`](./examples/diffsynth/)
|
||||||
|
|
||||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||||
|
|
||||||
### Image Synthesis
|
### Chinese Models
|
||||||
|
|
||||||
Generate high-resolution images by breaking the limitations of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/).
|
Use Hunyuan-DiT to generate images with Chinese prompts. We also support LoRA fine-tuning of this model. [`examples/hunyuan_dit`](./examples/hunyuan_dit/)
|
||||||
|
|
||||||
LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
|
Prompt: 少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感
|
||||||
|
|
||||||
|FLUX|Stable Diffusion 3|
|
|1024x1024|2048x2048 (highres-fix)|
|
||||||
|-|-|
|
|-|-|
|
||||||
|||
|
|||
|
||||||
|
|
||||||
|Kolors|Hunyuan-DiT|
|
Prompt: 一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉
|
||||||
|-|-|
|
|
||||||
|||
|
|
||||||
|
|
||||||
|Stable Diffusion|Stable Diffusion XL|
|
|Without LoRA|With LoRA|
|
||||||
|-|-|
|
|-|-|
|
||||||
|||
|
|||
|
||||||
|
|
||||||
## Usage (in WebUI)
|
## Usage (in WebUI)
|
||||||
|
|
||||||
Create stunning images using the painter, with assistance from AI!
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/95265d21-cdd6-4125-a7cb-9fbcf6ceb7b0
|
|
||||||
|
|
||||||
**This video is not rendered in real-time.**
|
|
||||||
|
|
||||||
Before launching the WebUI, please download models to the folder `./models`. See [here](#download-models).
|
|
||||||
|
|
||||||
* `Gradio` version
|
|
||||||
|
|
||||||
```
|
```
|
||||||
pip install gradio
|
python -m streamlit run DiffSynth_Studio.py
|
||||||
```
|
|
||||||
|
|
||||||
```
|
|
||||||
python apps/gradio/DiffSynth_Studio.py
|
|
||||||
```
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
* `Streamlit` version
|
|
||||||
|
|
||||||
```
|
|
||||||
pip install streamlit streamlit-drawable-canvas
|
|
||||||
```
|
|
||||||
|
|
||||||
```
|
|
||||||
python -m streamlit run apps/streamlit/DiffSynth_Studio.py
|
|
||||||
```
|
```
|
||||||
|
|
||||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954
|
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954
|
||||||
|
|||||||
@@ -1,252 +0,0 @@
|
|||||||
import gradio as gr
|
|
||||||
from diffsynth import ModelManager, SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
|
|
||||||
import os, torch
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
# Static settings for the painter WebUI.
# - "model_config" maps each model type shown in the "Model type" dropdown
#   (in insertion order) to:
#     "model_folder":       directory scanned by load_model_list / joined by load_model,
#     "pipeline_class":     class whose from_model_manager() builds the pipeline,
#     "default_parameters": values pushed into the UI by model_path_to_default_params.
# - "max_num_painter_layers" is the number of painter tabs created and the
#   chunk size used to split generate_image's *args.
# - "max_num_model_cache" bounds how many loaded models load_model keeps.
config = {
    "model_config": {
        "Stable Diffusion": {
            "model_folder": "models/stable_diffusion",
            "pipeline_class": SDImagePipeline,
            "default_parameters": {
                "cfg_scale": 7.0,
                "height": 512,
                "width": 512,
            }
        },
        "Stable Diffusion XL": {
            "model_folder": "models/stable_diffusion_xl",
            "pipeline_class": SDXLImagePipeline,
            "default_parameters": {
                "cfg_scale": 7.0,
            }
        },
        "Stable Diffusion 3": {
            "model_folder": "models/stable_diffusion_3",
            "pipeline_class": SD3ImagePipeline,
            "default_parameters": {
                "cfg_scale": 7.0,
            }
        },
        "Stable Diffusion XL Turbo": {
            "model_folder": "models/stable_diffusion_xl_turbo",
            "pipeline_class": SDXLImagePipeline,
            "default_parameters": {
                # NOTE(review): model_path_to_default_params does not read
                # "negative_prompt", so this entry currently has no effect.
                "negative_prompt": "",
                "cfg_scale": 1.0,
                "num_inference_steps": 1,
                "height": 512,
                "width": 512,
            }
        },
        "Kolors": {
            "model_folder": "models/kolors",
            "pipeline_class": SDXLImagePipeline,
            "default_parameters": {
                "cfg_scale": 7.0,
            }
        },
        "HunyuanDiT": {
            "model_folder": "models/HunyuanDiT",
            "pipeline_class": HunyuanDiTImagePipeline,
            "default_parameters": {
                "cfg_scale": 7.0,
            }
        },
        "FLUX": {
            "model_folder": "models/FLUX",
            "pipeline_class": FluxImagePipeline,
            "default_parameters": {
                "cfg_scale": 1.0,
            }
        }
    },
    "max_num_painter_layers": 8,
    "max_num_model_cache": 1,
}
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_list(model_type):
    """Return the sorted model choices available for ``model_type``.

    Args:
        model_type: A key of ``config["model_config"]``, or ``None`` when no
            model type has been selected yet.

    Returns:
        A sorted list of entry names inside the configured model folder:
        every ``*.safetensors`` entry, plus sub-directories for the model
        types that are loaded from a directory of components. The list is
        empty when ``model_type`` is ``None`` or the folder does not exist.
    """
    if model_type is None:
        return []
    folder = config["model_config"][model_type]["model_folder"]
    # These model types are loaded from a directory (see load_model), so
    # sub-directories are valid choices for them.
    accepts_directories = model_type in ("HunyuanDiT", "Kolors", "FLUX")
    try:
        # One scan of the folder instead of one per filter.
        entries = os.listdir(folder)
    except FileNotFoundError:
        # Models have not been downloaded yet: offer an empty dropdown rather
        # than failing the UI callback.
        return []
    return sorted(
        name
        for name in entries
        if name.endswith(".safetensors")
        or (accepts_directories and os.path.isdir(os.path.join(folder, name)))
    )
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_type, model_path):
    """Load (or fetch from the cache) the model manager and pipeline for a model.

    Args:
        model_type: A key of ``config["model_config"]``.
        model_path: Entry name inside that model type's ``model_folder``
            (a file, or a directory for HunyuanDiT / Kolors / FLUX).

    Returns:
        A ``(model_manager, pipe)`` tuple. The same tuple is returned from the
        module-level ``model_dict`` cache on subsequent calls with the same
        arguments.
    """
    model_key = f"{model_type}:{model_path}"
    if model_key in model_dict:
        return model_dict[model_key]

    # Evict BEFORE loading, so the weights of the old and the new model never
    # have to coexist. The oldest inserted entry goes first. The
    # `model_dict` guard stops the loop once the cache is empty, even if the
    # configured limit is below one.
    while model_dict and len(model_dict) + 1 > config["max_num_model_cache"]:
        key = next(iter(model_dict))
        model_manager_to_release, _ = model_dict[key]
        model_manager_to_release.to("cpu")
        del model_dict[key]
        torch.cuda.empty_cache()

    model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
    model_manager = ModelManager()
    if model_type == "HunyuanDiT":
        model_manager.load_models([
            os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
            os.path.join(model_path, "mt5/pytorch_model.bin"),
            os.path.join(model_path, "model/pytorch_model_ema.pt"),
            os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
        ])
    elif model_type == "Kolors":
        model_manager.load_models([
            os.path.join(model_path, "text_encoder"),
            os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
            os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
        ])
    elif model_type == "FLUX":
        model_manager.torch_dtype = torch.bfloat16
        file_list = [
            os.path.join(model_path, "text_encoder/model.safetensors"),
            os.path.join(model_path, "text_encoder_2"),
        ]
        # Every top-level .safetensors file in the FLUX folder is loaded too.
        file_list += [
            os.path.join(model_path, file_name)
            for file_name in os.listdir(model_path)
            if file_name.endswith(".safetensors")
        ]
        model_manager.load_models(file_list)
    else:
        # All other model types are a single checkpoint file.
        model_manager.load_model(model_path)

    pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
    model_dict[model_key] = model_manager, pipe
    return model_manager, pipe
|
|
||||||
|
|
||||||
|
|
||||||
model_dict = {}
|
|
||||||
|
|
||||||
with gr.Blocks() as app:
|
|
||||||
gr.Markdown("# DiffSynth-Studio Painter")
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column(scale=382, min_width=100):
|
|
||||||
|
|
||||||
with gr.Accordion(label="Model"):
|
|
||||||
model_type = gr.Dropdown(choices=[i for i in config["model_config"]], label="Model type")
|
|
||||||
model_path = gr.Dropdown(choices=[], interactive=True, label="Model path")
|
|
||||||
|
|
||||||
@gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
|
|
||||||
def model_type_to_model_path(model_type):
    """Build the dropdown update listing the models found for ``model_type``."""
    available_models = load_model_list(model_type)
    return gr.Dropdown(choices=available_models)
|
|
||||||
|
|
||||||
with gr.Accordion(label="Prompt"):
|
|
||||||
prompt = gr.Textbox(label="Prompt", lines=3)
|
|
||||||
negative_prompt = gr.Textbox(label="Negative prompt", lines=1)
|
|
||||||
cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
|
|
||||||
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale (only for FLUX)")
|
|
||||||
|
|
||||||
with gr.Accordion(label="Image"):
|
|
||||||
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
|
|
||||||
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
|
|
||||||
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
|
|
||||||
with gr.Column():
|
|
||||||
use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
|
|
||||||
seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)
|
|
||||||
|
|
||||||
@gr.on(
|
|
||||||
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
|
|
||||||
outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
|
|
||||||
triggers=model_path.change
|
|
||||||
)
|
|
||||||
def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
    """Load the selected model and apply its configured default parameters.

    Every argument after ``model_path`` is the current UI value. A value is
    replaced only when the model type's ``default_parameters`` defines it;
    otherwise the current value is returned unchanged.

    Returns:
        ``(prompt, negative_prompt, cfg_scale, embedded_guidance,
        num_inference_steps, height, width)``.
    """
    load_model(model_type, model_path)
    defaults = config["model_config"][model_type]["default_parameters"]
    # "negative_prompt" is a configured default (SDXL Turbo sets it to "") and
    # this callback outputs it, so it is applied like the numeric settings.
    negative_prompt = defaults.get("negative_prompt", negative_prompt)
    cfg_scale = defaults.get("cfg_scale", cfg_scale)
    embedded_guidance = defaults.get("embedded_guidance", embedded_guidance)
    num_inference_steps = defaults.get("num_inference_steps", num_inference_steps)
    height = defaults.get("height", height)
    width = defaults.get("width", width)
    return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
|
|
||||||
|
|
||||||
|
|
||||||
with gr.Column(scale=618, min_width=100):
|
|
||||||
with gr.Accordion(label="Painter"):
|
|
||||||
enable_local_prompt_list = []
|
|
||||||
local_prompt_list = []
|
|
||||||
mask_scale_list = []
|
|
||||||
canvas_list = []
|
|
||||||
for painter_layer_id in range(config["max_num_painter_layers"]):
|
|
||||||
with gr.Tab(label=f"Layer {painter_layer_id}"):
|
|
||||||
enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
|
|
||||||
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
|
|
||||||
mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
|
|
||||||
canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
|
|
||||||
brush=gr.Brush(default_size=100, default_color="#000000", colors=["#000000"]),
|
|
||||||
label="Painter", key=f"canvas_{painter_layer_id}")
|
|
||||||
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden")
|
|
||||||
def resize_canvas(height, width, canvas):
    """Return ``canvas`` if its background is already ``height`` x ``width``.

    Otherwise return a fresh all-white ``(height, width, 3)`` uint8 array.
    """
    current_height, current_width = canvas["background"].shape[:2]
    if current_height == height and current_width == width:
        return canvas
    return np.full((height, width, 3), 255, dtype=np.uint8)
|
|
||||||
|
|
||||||
enable_local_prompt_list.append(enable_local_prompt)
|
|
||||||
local_prompt_list.append(local_prompt)
|
|
||||||
mask_scale_list.append(mask_scale)
|
|
||||||
canvas_list.append(canvas)
|
|
||||||
with gr.Accordion(label="Results"):
|
|
||||||
run_button = gr.Button(value="Generate", variant="primary")
|
|
||||||
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column():
|
|
||||||
output_to_painter_button = gr.Button(value="Set as painter's background")
|
|
||||||
with gr.Column():
|
|
||||||
output_to_input_button = gr.Button(value="Set as input image")
|
|
||||||
painter_background = gr.State(None)
|
|
||||||
input_background = gr.State(None)
|
|
||||||
@gr.on(
|
|
||||||
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
|
|
||||||
outputs=[output_image],
|
|
||||||
triggers=run_button.click
|
|
||||||
)
|
|
||||||
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
    """Run the selected pipeline with the global prompt and the enabled painter layers.

    ``*args`` carries four consecutive groups of
    ``config["max_num_painter_layers"]`` values each: enable flags, local
    prompts, mask scales and canvases. Only layers whose flag is truthy are
    used. A layer's mask is the last channel of the first entry in its
    canvas ``"layers"`` list.

    Returns:
        Whatever the pipeline call returns.
    """
    _, pipe = load_model(model_type, model_path)
    input_params = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        cfg_scale=cfg_scale,
        num_inference_steps=num_inference_steps,
        height=height,
        width=width,
        progress_bar_cmd=progress.tqdm,
    )
    if isinstance(pipe, FluxImagePipeline):
        input_params["embedded_guidance"] = embedded_guidance

    # Split the flat argument tuple back into its four per-layer groups.
    num_layers = config["max_num_painter_layers"]
    enable_flags, layer_prompts, layer_scales, layer_canvases = (
        args[group * num_layers:(group + 1) * num_layers] for group in range(4)
    )

    local_prompts, masks, mask_scales = [], [], []
    for enabled, layer_prompt, layer_scale, layer_canvas in zip(
        enable_flags, layer_prompts, layer_scales, layer_canvases
    ):
        if not enabled:
            continue
        local_prompts.append(layer_prompt)
        alpha_channel = layer_canvas["layers"][0][:, :, -1]
        masks.append(Image.fromarray(alpha_channel).convert("RGB"))
        mask_scales.append(layer_scale)

    input_params["local_prompts"] = local_prompts
    input_params["masks"] = masks
    input_params["mask_scales"] = mask_scales

    torch.manual_seed(seed)
    return pipe(**input_params)
|
|
||||||
|
|
||||||
@gr.on(inputs=[output_image] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
|
|
||||||
def send_output_to_painter_background(output_image, *canvas_list):
    """Use the generated image as the background of every painter canvas.

    Args:
        output_image: The current output image, or ``None`` when nothing has
            been generated yet. It must offer ``resize((width, height))``.
        *canvas_list: Canvas values, each with a ``"background"`` entry that
            exposes ``.shape``.

    Returns:
        The canvases as a tuple. They are modified in place, and returned
        untouched when ``output_image`` is ``None``.
    """
    # The button can be clicked before the first generation. Leave the
    # canvases alone instead of failing on None.
    if output_image is None:
        return tuple(canvas_list)
    for canvas in canvas_list:
        h, w = canvas["background"].shape[:2]
        # resize() is given (width, height) while shape is (height, width).
        canvas["background"] = output_image.resize((w, h))
    return tuple(canvas_list)
|
|
||||||
app.launch()
|
|
||||||
@@ -1,390 +0,0 @@
|
|||||||
import os
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
|
||||||
import random
|
|
||||||
import json
|
|
||||||
import gradio as gr
|
|
||||||
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
|
|
||||||
from modelscope import dataset_snapshot_download
|
|
||||||
|
|
||||||
|
|
||||||
# Download the EliGen UI examples into ./data/examples/eligen/entity_control/.
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern="data/examples/eligen/entity_control/*",
)
example_json = 'data/examples/eligen/entity_control/ui_examples.json'
# Read as UTF-8 explicitly so the result does not depend on the locale.
with open(example_json, 'r', encoding='utf-8') as f:
    examples = json.load(f)['examples']


def _load_rgb_mask(path):
    """Read one mask image into memory as RGB and close the underlying file."""
    with Image.open(path) as mask_file:
        return mask_file.convert('RGB')


# Attach one RGB mask per local prompt to each example; mask files are named
# 0.png, 1.png, ... inside the example's own folder.
for idx in range(len(examples)):
    example_id = examples[idx]['example_id']
    entity_prompts = examples[idx]['local_prompt_list']
    examples[idx]['mask_lists'] = [
        _load_rgb_mask(f"data/examples/eligen/entity_control/example_{example_id}/{i}.png")
        for i in range(len(entity_prompts))
    ]
|
|
||||||
|
|
||||||
def create_canvas_data(background, masks):
    """Assemble an image-editor value from a background and a list of masks.

    Args:
        background: Array whose last axis has 3 or 4 channels. A 3-channel
            background gets a fully opaque (255) fourth channel appended.
        masks: Iterable of 2-D arrays, arrays with a trailing channel axis
            (only channel 0 is used), or ``None`` for an empty layer.

    Returns:
        A dict with ``"background"``, ``"layers"`` (one 4-channel uint8 array
        per mask, the mask stored in the last channel) and ``"composite"``
        (the background overwritten by each layer wherever that layer's last
        channel is greater than zero).
    """
    if background.shape[-1] == 3:
        opaque_alpha = np.full(background.shape[:2], 255, dtype=np.uint8)
        background = np.dstack([background, opaque_alpha])

    def to_layer(mask):
        # A missing mask becomes an all-zero layer shaped like the background.
        if mask is None:
            return np.zeros_like(background)
        alpha = mask if mask.ndim == 2 else mask[..., 0]
        rgba = np.zeros((alpha.shape[0], alpha.shape[1], 4), dtype=np.uint8)
        rgba[..., -1] = alpha
        return rgba

    layers = [to_layer(mask) for mask in masks]

    composite = background.copy()
    for rgba in layers:
        if rgba.size == 0:
            continue
        painted = rgba[..., -1:] > 0
        composite = np.where(painted, rgba, composite)

    return {
        "background": background,
        "layers": layers,
        "composite": composite,
    }
|
|
||||||
|
|
||||||
def load_example(load_example_button):
    """Turn a "Load Example N" button label into the values that fill the UI.

    The trailing number of the label, minus one, indexes the module-level
    ``examples`` list.

    Returns:
        A flat list: ``50``, the global prompt, the negative prompt, the seed,
        then ``config["max_num_painter_layers"]`` local prompts (padded with
        ``""``) and the same number of canvas dicts (padded with blank masks).
    """
    max_layers = config["max_num_painter_layers"]
    example = examples[int(load_example_button.split()[-1]) - 1]

    result = [
        50,
        example["global_prompt"],
        example["negative_prompt"],
        example["seed"],
    ]
    result.extend(example["local_prompt_list"])
    result.extend([""] * (max_layers - len(example["local_prompt_list"])))

    masks = [np.array(mask.convert("L")) for mask in example["mask_lists"]]
    # Pad with blank masks shaped like the first one, or 512x512 when there
    # is no mask at all.
    while len(masks) < max_layers:
        if masks:
            masks.append(np.zeros_like(masks[0]))
        else:
            masks.append(np.zeros((512, 512), dtype=np.uint8))

    mask_height, mask_width = masks[0].shape[0], masks[0].shape[1]
    background = np.ones((mask_height, mask_width, 4), dtype=np.uint8) * 255
    result.extend(create_canvas_data(background, [mask]) for mask in masks)
    return result
|
|
||||||
|
|
||||||
def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
    """Dump masks and their prompts under ``workdirs/tmp_mask/<random_dir>``.

    Each mask is saved as ``<index>.png`` through its own ``save`` method.
    The prompts and seed are written to ``prompts.json`` in the same folder.
    """
    save_dir = os.path.join('workdirs/tmp_mask', random_dir)
    print(f'save to {save_dir}')
    os.makedirs(save_dir, exist_ok=True)

    for index, mask in enumerate(masks):
        mask.save(os.path.join(save_dir, f'{index}.png'))

    prompts_path = os.path.join(save_dir, "prompts.json")
    with open(prompts_path, 'w') as f:
        json.dump(
            {
                "global_prompt": global_prompt,
                "mask_prompts": mask_prompts,
                "seed": seed,
            },
            f,
            indent=4,
        )
|
|
||||||
|
|
||||||
def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
    """Overlay tinted masks and their prompt labels on ``image``.

    Args:
        image: PIL image to annotate; it is converted to RGBA for compositing.
        masks: PIL images (or ``None`` entries, which are skipped). Only pure
            white ``(255, 255, 255)`` pixels are treated as part of a mask.
            Each mask must have the same size as ``image``.
        mask_prompts: One label per mask, drawn near the top-left corner of
            the mask's bounding box.
        font_size: Size passed to the font loader.
        use_random_colors: Draw a random colour per mask instead of using the
            fixed palette.

    Returns:
        A new RGBA image with all mask overlays composited on top.
    """
    if use_random_colors:
        colors = [
            (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80)
            for _ in range(len(masks))
        ]
    else:
        # Fixed palette. It is cycled below so that no mask is dropped when
        # there are more masks than colours.
        colors = [
            (165, 238, 173, 80),
            (76, 102, 221, 80),
            (221, 160, 77, 80),
            (204, 93, 71, 80),
            (145, 187, 149, 80),
            (134, 141, 172, 80),
            (157, 137, 109, 80),
            (153, 104, 95, 80),
        ]

    try:
        font = ImageFont.truetype("arial", font_size)  # Adjust as needed
    except IOError:
        font = ImageFont.load_default(font_size)

    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
    for index, (mask, mask_prompt) in enumerate(zip(masks, mask_prompts)):
        if mask is None:
            continue
        # Skip empty masks before doing any pixel work.
        mask_bbox = mask.getbbox()
        if mask_bbox is None:
            continue
        color = colors[index % len(colors)]

        # Tint pure-white mask pixels with the layer colour and leave all other
        # pixels fully transparent, in one vectorised pass.
        is_white = (np.asarray(mask.convert('RGB')) == 255).all(axis=-1)
        tinted = np.zeros((is_white.shape[0], is_white.shape[1], 4), dtype=np.uint8)
        tinted[is_white] = color
        mask_rgba = Image.fromarray(tinted)

        draw = ImageDraw.Draw(mask_rgba)
        text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)
        draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)

        overlay = Image.alpha_composite(overlay, mask_rgba)

    return Image.alpha_composite(image.convert('RGBA'), overlay)
|
|
||||||
|
|
||||||
# Static settings for the EliGen entity-control demo.
# - "model_config" holds the single supported model type, "FLUX";
#   "pipeline_class" is the class load_model builds the pipeline from.
# - "max_num_painter_layers" is the number of entity tabs created and the
#   chunk size used to split generate_image's *args and to pad load_example.
# - NOTE(review): "max_num_model_cache" and "default_parameters" do not appear
#   to be read by the code visible in this script. Confirm before relying on them.
config = {
    "model_config": {
        "FLUX": {
            "model_folder": "models/FLUX",
            "pipeline_class": FluxImagePipeline,
            "default_parameters": {
                "cfg_scale": 3.0,
                "embedded_guidance": 3.5,
                "num_inference_steps": 30,
            }
        },
    },
    "max_num_painter_layers": 8,
    "max_num_model_cache": 1,
}
|
|
||||||
|
|
||||||
model_dict = {}
|
|
||||||
|
|
||||||
def load_model(model_type='FLUX', model_path='FLUX.1-dev'):
    """Build (or fetch from the cache) the FLUX pipeline with the EliGen LoRA.

    Args:
        model_type: Key of ``config["model_config"]``. It selects the pipeline
            class and is part of the cache key.
        model_path: Used only as part of the cache key. The checkpoint itself
            is always requested through ``model_id_list=["FLUX.1-dev"]``.

    Returns:
        A ``(model_manager, pipe)`` tuple, cached in the module-level
        ``model_dict``.

    Raises:
        KeyError: If ``model_type`` is not configured. This is raised before
            anything is downloaded or loaded.
    """
    model_key = f"{model_type}:{model_path}"
    if model_key in model_dict:
        return model_dict[model_key]
    # Look the settings up first so an unknown model type fails fast.
    model_settings = config["model_config"][model_type]
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
    model_manager.load_lora(
        download_customized_models(
            model_id="DiffSynth-Studio/Eligen",
            origin_file_path="model_bf16.safetensors",
            local_dir="models/lora/entity_control",
        ),
        lora_alpha=1,
    )
    pipe = model_settings["pipeline_class"].from_model_manager(model_manager)
    model_dict[model_key] = model_manager, pipe
    return model_manager, pipe
|
|
||||||
|
|
||||||
|
|
||||||
with gr.Blocks() as app:
|
|
||||||
gr.Markdown(
|
|
||||||
"""## EliGen: Entity-Level Controllable Text-to-Image Model
|
|
||||||
1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
|
|
||||||
2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
|
|
||||||
3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
|
|
||||||
4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
|
|
||||||
main_interface = gr.Column(visible=False)
|
|
||||||
|
|
||||||
def initialize_model():
    """Load the model once and report the outcome through the UI.

    Returns:
        A dict keyed by the ``loading_status`` and ``main_interface``
        components. The status box is hidden on success and shows the error
        text on failure. The main interface is made visible in both cases.
    """
    try:
        load_model()
        status_update = gr.update(value="Model loaded successfully!", visible=False)
    except Exception as e:
        print(f'Failed to load model with error: {e}')
        status_update = gr.update(value=f"Failed to load model: {str(e)}", visible=True)
    return {
        loading_status: status_update,
        main_interface: gr.update(visible=True),
    }
|
|
||||||
|
|
||||||
app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])
|
|
||||||
|
|
||||||
with main_interface:
|
|
||||||
with gr.Row():
|
|
||||||
local_prompt_list = []
|
|
||||||
canvas_list = []
|
|
||||||
random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
|
|
||||||
with gr.Column(scale=382, min_width=100):
|
|
||||||
model_type = gr.State('FLUX')
|
|
||||||
model_path = gr.State('FLUX.1-dev')
|
|
||||||
with gr.Accordion(label="Global prompt"):
|
|
||||||
prompt = gr.Textbox(label="Global Prompt", lines=3)
|
|
||||||
negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur,", lines=3)
|
|
||||||
with gr.Accordion(label="Inference Options", open=True):
|
|
||||||
seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
|
|
||||||
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
|
|
||||||
cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=3.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
|
|
||||||
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=3.5, step=0.1, interactive=True, label="Embedded guidance scale")
|
|
||||||
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
|
|
||||||
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
|
|
||||||
with gr.Accordion(label="Inpaint Input Image", open=False):
|
|
||||||
input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
|
|
||||||
background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)
|
|
||||||
|
|
||||||
with gr.Column():
|
|
||||||
reset_input_button = gr.Button(value="Reset Inpaint Input")
|
|
||||||
send_input_to_painter = gr.Button(value="Set as painter's background")
|
|
||||||
@gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
|
|
||||||
def reset_input_image(input_image):
    """Return ``None`` regardless of the current ``input_image``.

    The decorator on this callback routes the result back into the inpaint
    input image when the "Reset Inpaint Input" button is clicked.
    """
    return None
|
|
||||||
|
|
||||||
with gr.Column(scale=618, min_width=100):
|
|
||||||
with gr.Accordion(label="Entity Painter"):
|
|
||||||
for painter_layer_id in range(config["max_num_painter_layers"]):
|
|
||||||
with gr.Tab(label=f"Entity {painter_layer_id}"):
|
|
||||||
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
|
|
||||||
canvas = gr.ImageEditor(
|
|
||||||
canvas_size=(512, 512),
|
|
||||||
sources=None,
|
|
||||||
layers=False,
|
|
||||||
interactive=True,
|
|
||||||
image_mode="RGBA",
|
|
||||||
brush=gr.Brush(
|
|
||||||
default_size=50,
|
|
||||||
default_color="#000000",
|
|
||||||
colors=["#000000"],
|
|
||||||
),
|
|
||||||
label="Entity Mask Painter",
|
|
||||||
key=f"canvas_{painter_layer_id}",
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
)
|
|
||||||
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
|
|
||||||
def resize_canvas(height, width, canvas):
    """Return ``canvas`` if its background is already ``height`` x ``width``.

    Otherwise return a fresh all-white ``(height, width, 3)`` uint8 array.
    """
    current_height, current_width = canvas["background"].shape[:2]
    if current_height == height and current_width == width:
        return canvas
    return np.full((height, width, 3), 255, dtype=np.uint8)
|
|
||||||
local_prompt_list.append(local_prompt)
|
|
||||||
canvas_list.append(canvas)
|
|
||||||
with gr.Accordion(label="Results"):
|
|
||||||
run_button = gr.Button(value="Generate", variant="primary")
|
|
||||||
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column():
|
|
||||||
output_to_painter_button = gr.Button(value="Set as painter's background")
|
|
||||||
with gr.Column():
|
|
||||||
return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
|
|
||||||
output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
|
|
||||||
real_output = gr.State(None)
|
|
||||||
mask_out = gr.State(None)
|
|
||||||
|
|
||||||
@gr.on(
|
|
||||||
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
|
|
||||||
outputs=[output_image, real_output, mask_out],
|
|
||||||
triggers=run_button.click
|
|
||||||
)
|
|
||||||
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
    """Generate an image from the global prompt and the named entity masks.

    ``*args`` carries ``config["max_num_painter_layers"]`` local prompts
    followed by the same number of canvases. A layer is used only when its
    local prompt is a non-empty string. Its mask is the last channel of the
    first entry in the canvas ``"layers"`` list. ``background_weight`` and
    ``random_mask_dir`` are accepted but not used.

    Returns:
        ``(shown_image, real_output, mask_out)``. ``shown_image`` is the mask
        visualisation when ``return_with_mask`` is truthy, else the plain
        image. The other two are ``gr.State`` objects wrapping the plain
        image and the visualisation.
    """
    _, pipe = load_model(model_type, model_path)
    input_params = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        cfg_scale=cfg_scale,
        num_inference_steps=num_inference_steps,
        height=height,
        width=width,
        progress_bar_cmd=progress.tqdm,
    )
    if isinstance(pipe, FluxImagePipeline):
        input_params["embedded_guidance"] = embedded_guidance
    if input_image is not None:
        # An input image switches the pipeline call to inpainting.
        input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
        input_params["enable_eligen_inpaint"] = True

    num_layers = config["max_num_painter_layers"]
    layer_prompts = args[:num_layers]
    layer_canvases = args[num_layers:2 * num_layers]

    local_prompts, masks = [], []
    for layer_prompt, layer_canvas in zip(layer_prompts, layer_canvases):
        if not isinstance(layer_prompt, str) or len(layer_prompt) == 0:
            continue
        local_prompts.append(layer_prompt)
        alpha_channel = layer_canvas["layers"][0][:, :, -1]
        masks.append(Image.fromarray(alpha_channel).convert("RGB"))

    # The pipeline receives None rather than empty lists when no entity is set.
    input_params["eligen_entity_prompts"] = local_prompts if local_prompts else None
    input_params["eligen_entity_masks"] = masks if masks else None

    torch.manual_seed(seed)
    image = pipe(**input_params)
    masks = [mask.resize(image.size) for mask in masks]
    image_with_mask = visualize_masks(image, masks, local_prompts)

    real_output = gr.State(image)
    mask_out = gr.State(image_with_mask)

    shown_image = image_with_mask if return_with_mask else image
    return shown_image, real_output, mask_out
|
|
||||||
|
|
||||||
@gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
|
|
||||||
def send_input_to_painter_background(input_image, *canvas_list):
    """Use the inpaint input image as the background of every canvas.

    Canvases are modified in place and returned as a tuple. Nothing changes
    when ``input_image`` is ``None``.
    """
    if input_image is not None:
        for canvas in canvas_list:
            canvas_height, canvas_width = canvas["background"].shape[:2]
            # resize() is given (width, height) while shape is (height, width).
            canvas["background"] = input_image.resize((canvas_width, canvas_height))
    return tuple(canvas_list)
|
|
||||||
@gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
def send_output_to_painter_background(real_output, *canvas_list):
    """Use the stored generation result as the background of every painter canvas.

    Args:
        real_output: Holder of the last generated image, read through its
            ``.value`` attribute; ``None`` when nothing is stored.
        *canvas_list: Current values of the painter canvases; each one is
            indexed by the ``"background"`` key.

    Returns:
        A tuple holding every canvas, in the order received. When
        ``real_output`` is ``None`` the canvases are returned untouched.
    """
    if real_output is not None:
        for painter in canvas_list:
            # Resize to the canvas's existing background size so the canvas
            # dimensions do not change.
            # NOTE(review): presumably the background is a (height, width, ...)
            # array and real_output.value is a PIL image -- confirm.
            bg_height, bg_width = painter["background"].shape[:2]
            painter["background"] = real_output.value.resize((bg_width, bg_height))
    return tuple(canvas_list)
|
|
||||||
@gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
def show_output(return_with_mask, real_output, mask_out):
    """Pick which stored image to show when the mask toggle changes.

    Args:
        return_with_mask: Truthy to show the mask-overlaid image.
        real_output: Holder of the plain generated image (read via ``.value``).
        mask_out: Holder of the mask-overlaid image (read via ``.value``).

    Returns:
        ``mask_out.value`` when ``return_with_mask`` is truthy, otherwise
        ``real_output.value``.
    """
    # Only the selected holder is dereferenced, as in an if/else.
    selected = mask_out if return_with_mask else real_output
    return selected.value
|
|
||||||
@gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
def send_output_to_pipe_input(real_output):
    """Feed the stored generation result back into the input image component.

    Args:
        real_output: Holder of the last generated image (read via ``.value``).

    Returns:
        The value held by ``real_output``.
    """
    latest_image = real_output.value
    return latest_image
|
|
||||||
|
|
||||||
with gr.Column():
    gr.Markdown("## Examples")
    # Lay the examples out two per row; `i` is the index of the row's first card.
    for i in range(0, len(examples), 2):
        with gr.Row():
            # Each row has two slots; the second is empty when the number of
            # examples is odd.
            for slot in (i, i + 1):
                if slot >= len(examples):
                    continue
                example = examples[slot]
                with gr.Column():
                    # Preview image for this example, read from the example's data directory.
                    example_image = gr.Image(
                        value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
                        label=example["description"],
                        interactive=False,
                        width=1024,
                        height=512
                    )
                    # The button is passed as its own input to load_example.
                    load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
                    load_example_button.click(
                        load_example,
                        inputs=[load_example_button],
                        outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
                    )
# Set the app-level "show_progress" config entry to "hidden", then start the server.
app.config["show_progress"] = "hidden"
app.launch()
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
from .data import *
|
from .data import *
|
||||||
from .models import *
|
from .models import *
|
||||||
from .prompters import *
|
from .prompts import *
|
||||||
from .schedulers import *
|
from .schedulers import *
|
||||||
from .pipelines import *
|
from .pipelines import *
|
||||||
from .controlnets import *
|
from .controlnets import *
|
||||||
|
|||||||
@@ -1,736 +0,0 @@
|
|||||||
from typing_extensions import Literal, TypeAlias
|
|
||||||
|
|
||||||
from ..models.sd_text_encoder import SDTextEncoder
|
|
||||||
from ..models.sd_unet import SDUNet
|
|
||||||
from ..models.sd_vae_encoder import SDVAEEncoder
|
|
||||||
from ..models.sd_vae_decoder import SDVAEDecoder
|
|
||||||
|
|
||||||
from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
|
||||||
from ..models.sdxl_unet import SDXLUNet
|
|
||||||
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
|
||||||
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
|
||||||
|
|
||||||
from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
|
||||||
from ..models.sd3_dit import SD3DiT
|
|
||||||
from ..models.sd3_vae_decoder import SD3VAEDecoder
|
|
||||||
from ..models.sd3_vae_encoder import SD3VAEEncoder
|
|
||||||
|
|
||||||
from ..models.sd_controlnet import SDControlNet
|
|
||||||
from ..models.sdxl_controlnet import SDXLControlNetUnion
|
|
||||||
|
|
||||||
from ..models.sd_motion import SDMotionModel
|
|
||||||
from ..models.sdxl_motion import SDXLMotionModel
|
|
||||||
|
|
||||||
from ..models.svd_image_encoder import SVDImageEncoder
|
|
||||||
from ..models.svd_unet import SVDUNet
|
|
||||||
from ..models.svd_vae_decoder import SVDVAEDecoder
|
|
||||||
from ..models.svd_vae_encoder import SVDVAEEncoder
|
|
||||||
|
|
||||||
from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
|
||||||
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
|
||||||
|
|
||||||
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
|
||||||
from ..models.hunyuan_dit import HunyuanDiT
|
|
||||||
|
|
||||||
from ..models.flux_dit import FluxDiT
|
|
||||||
from ..models.flux_text_encoder import FluxTextEncoder2
|
|
||||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
|
||||||
from ..models.flux_controlnet import FluxControlNet
|
|
||||||
from ..models.flux_ipadapter import FluxIpAdapter
|
|
||||||
|
|
||||||
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
|
|
||||||
from ..models.cog_dit import CogDiT
|
|
||||||
|
|
||||||
from ..models.omnigen import OmniGenTransformer
|
|
||||||
|
|
||||||
from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
|
||||||
from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
|
|
||||||
|
|
||||||
from ..extensions.RIFE import IFNet
|
|
||||||
from ..extensions.ESRGAN import RRDBNet
|
|
||||||
|
|
||||||
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
|
||||||
|
|
||||||
|
|
||||||
model_loader_configs = [
|
|
||||||
# These configs are provided for detecting model type automatically.
|
|
||||||
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
|
|
||||||
(None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
|
|
||||||
(None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
|
|
||||||
(None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
|
|
||||||
(None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
|
|
||||||
(None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
|
|
||||||
(None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
|
|
||||||
(None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
|
|
||||||
(None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
|
|
||||||
(None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
|
|
||||||
(None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
|
|
||||||
(None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
|
|
||||||
(None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
|
|
||||||
(None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
|
|
||||||
(None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
|
|
||||||
(None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
|
||||||
(None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
|
||||||
(None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
|
|
||||||
(None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
|
|
||||||
(None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
|
|
||||||
(None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
|
|
||||||
(None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
|
|
||||||
(None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
|
|
||||||
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
|
|
||||||
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
|
|
||||||
(None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
|
|
||||||
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
|
||||||
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
|
|
||||||
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
|
|
||||||
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
|
|
||||||
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
|
||||||
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
|
|
||||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
|
||||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
|
||||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
|
||||||
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
|
||||||
(None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers"),
|
|
||||||
(None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
|
||||||
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
|
||||||
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
|
||||||
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
|
||||||
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
|
|
||||||
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
|
||||||
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
|
||||||
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
|
||||||
(None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
|
||||||
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
|
|
||||||
(None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
|
||||||
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
|
||||||
]
|
|
||||||
huggingface_model_loader_configs = [
|
|
||||||
# These configs are provided for detecting model type automatically.
|
|
||||||
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
|
|
||||||
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
|
|
||||||
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
|
|
||||||
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
|
|
||||||
("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
|
|
||||||
# ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
|
|
||||||
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
|
|
||||||
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
|
||||||
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
|
|
||||||
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder")
|
|
||||||
]
|
|
||||||
patch_model_loader_configs = [
|
|
||||||
# These configs are provided for detecting model type automatically.
|
|
||||||
# The format is (state_dict_keys_hash_with_shape, model_names, model_classes, extra_kwargs)
|
|
||||||
("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
|
|
||||||
]
|
|
||||||
|
|
||||||
preset_models_on_huggingface = {
|
|
||||||
"HunyuanDiT": [
|
|
||||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
|
||||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
|
||||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
|
||||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
|
||||||
],
|
|
||||||
"stable-video-diffusion-img2vid-xt": [
|
|
||||||
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
|
||||||
],
|
|
||||||
"ExVideo-SVD-128f-v1": [
|
|
||||||
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
|
||||||
],
|
|
||||||
# Stable Diffusion
|
|
||||||
"StableDiffusion_v15": [
|
|
||||||
("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
|
||||||
],
|
|
||||||
"DreamShaper_8": [
|
|
||||||
("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
|
||||||
],
|
|
||||||
# Textual Inversion
|
|
||||||
"TextualInversion_VeryBadImageNegative_v1.3": [
|
|
||||||
("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
|
||||||
],
|
|
||||||
# Stable Diffusion XL
|
|
||||||
"StableDiffusionXL_v1": [
|
|
||||||
("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
|
||||||
],
|
|
||||||
"BluePencilXL_v200": [
|
|
||||||
("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
|
||||||
],
|
|
||||||
"StableDiffusionXL_Turbo": [
|
|
||||||
("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
|
||||||
],
|
|
||||||
# Stable Diffusion 3
|
|
||||||
"StableDiffusion3": [
|
|
||||||
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
|
||||||
],
|
|
||||||
"StableDiffusion3_without_T5": [
|
|
||||||
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
|
||||||
],
|
|
||||||
# ControlNet
|
|
||||||
"ControlNet_v11f1p_sd15_depth": [
|
|
||||||
("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
|
||||||
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
|
||||||
],
|
|
||||||
"ControlNet_v11p_sd15_softedge": [
|
|
||||||
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
|
||||||
("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
|
|
||||||
],
|
|
||||||
"ControlNet_v11f1e_sd15_tile": [
|
|
||||||
("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
|
||||||
],
|
|
||||||
"ControlNet_v11p_sd15_lineart": [
|
|
||||||
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
|
||||||
("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
|
|
||||||
("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
|
|
||||||
],
|
|
||||||
"ControlNet_union_sdxl_promax": [
|
|
||||||
("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
|
||||||
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
|
||||||
],
|
|
||||||
# AnimateDiff
|
|
||||||
"AnimateDiff_v2": [
|
|
||||||
("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
|
||||||
],
|
|
||||||
"AnimateDiff_xl_beta": [
|
|
||||||
("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
|
||||||
],
|
|
||||||
|
|
||||||
# Qwen Prompt
|
|
||||||
"QwenPrompt": [
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
],
|
|
||||||
# Beautiful Prompt
|
|
||||||
"BeautifulPrompt": [
|
|
||||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
],
|
|
||||||
# Omost prompt
|
|
||||||
"OmostPrompt":[
|
|
||||||
("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
],
|
|
||||||
# Translator
|
|
||||||
"opus-mt-zh-en": [
|
|
||||||
("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
|
||||||
("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
|
||||||
("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
|
||||||
("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
],
|
|
||||||
# IP-Adapter
|
|
||||||
"IP-Adapter-SD": [
|
|
||||||
("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
|
||||||
("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
|
||||||
],
|
|
||||||
"IP-Adapter-SDXL": [
|
|
||||||
("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
|
||||||
("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
|
||||||
],
|
|
||||||
"SDXL-vae-fp16-fix": [
|
|
||||||
("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
|
||||||
],
|
|
||||||
# Kolors
|
|
||||||
"Kolors": [
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
|
||||||
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
|
||||||
],
|
|
||||||
# FLUX
|
|
||||||
"FLUX.1-dev": [
|
|
||||||
("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
|
||||||
("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
|
||||||
("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
|
||||||
],
|
|
||||||
"InstantX/FLUX.1-dev-IP-Adapter": {
|
|
||||||
"file_list": [
|
|
||||||
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
|
|
||||||
("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
|
||||||
("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
|
||||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
# RIFE
|
|
||||||
"RIFE": [
|
|
||||||
("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
|
|
||||||
],
|
|
||||||
# CogVideo
|
|
||||||
"CogVideoX-5B": [
|
|
||||||
("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
|
||||||
("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
|
||||||
("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
|
||||||
("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
|
||||||
("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
|
||||||
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
|
||||||
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
|
||||||
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
|
||||||
("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
|
||||||
],
|
|
||||||
# Stable Diffusion 3.5
|
|
||||||
"StableDiffusion3.5-large": [
|
|
||||||
("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
|
|
||||||
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
],
|
|
||||||
}
|
|
||||||
preset_models_on_modelscope = {
|
|
||||||
# Hunyuan DiT
|
|
||||||
"HunyuanDiT": [
|
|
||||||
("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
|
||||||
("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
|
||||||
("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
|
||||||
("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
|
||||||
],
|
|
||||||
# Stable Video Diffusion
|
|
||||||
"stable-video-diffusion-img2vid-xt": [
|
|
||||||
("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
|
||||||
],
|
|
||||||
# ExVideo
|
|
||||||
"ExVideo-SVD-128f-v1": [
|
|
||||||
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
|
||||||
],
|
|
||||||
"ExVideo-CogVideoX-LoRA-129f-v1": [
|
|
||||||
("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
|
|
||||||
],
|
|
||||||
# Stable Diffusion
|
|
||||||
"StableDiffusion_v15": [
|
|
||||||
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
|
||||||
],
|
|
||||||
"DreamShaper_8": [
|
|
||||||
("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
|
||||||
],
|
|
||||||
"AingDiffusion_v12": [
|
|
||||||
("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
|
|
||||||
],
|
|
||||||
"Flat2DAnimerge_v45Sharp": [
|
|
||||||
("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
|
|
||||||
],
|
|
||||||
# Textual Inversion
|
|
||||||
"TextualInversion_VeryBadImageNegative_v1.3": [
|
|
||||||
("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
|
||||||
],
|
|
||||||
# Stable Diffusion XL
|
|
||||||
"StableDiffusionXL_v1": [
|
|
||||||
("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
|
||||||
],
|
|
||||||
"BluePencilXL_v200": [
|
|
||||||
("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
|
||||||
],
|
|
||||||
"StableDiffusionXL_Turbo": [
|
|
||||||
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
|
||||||
],
|
|
||||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
|
|
||||||
("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
|
|
||||||
],
|
|
||||||
# Stable Diffusion 3
|
|
||||||
"StableDiffusion3": [
|
|
||||||
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
|
||||||
],
|
|
||||||
"StableDiffusion3_without_T5": [
|
|
||||||
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
|
||||||
],
|
|
||||||
# ControlNet
|
|
||||||
"ControlNet_v11f1p_sd15_depth": [
|
|
||||||
("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
|
||||||
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
|
||||||
],
|
|
||||||
"ControlNet_v11p_sd15_softedge": [
|
|
||||||
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
|
||||||
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
|
|
||||||
],
|
|
||||||
"ControlNet_v11f1e_sd15_tile": [
|
|
||||||
("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
|
||||||
],
|
|
||||||
"ControlNet_v11p_sd15_lineart": [
|
|
||||||
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
|
||||||
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
|
||||||
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
|
|
||||||
],
|
|
||||||
"ControlNet_union_sdxl_promax": [
|
|
||||||
("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
|
||||||
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
|
||||||
],
|
|
||||||
"Annotators:Depth": [
|
|
||||||
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
|
|
||||||
],
|
|
||||||
"Annotators:Softedge": [
|
|
||||||
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
|
|
||||||
],
|
|
||||||
"Annotators:Lineart": [
|
|
||||||
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
|
||||||
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
|
|
||||||
],
|
|
||||||
"Annotators:Normal": [
|
|
||||||
("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
|
|
||||||
],
|
|
||||||
"Annotators:Openpose": [
|
|
||||||
("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
|
|
||||||
("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
|
|
||||||
("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
|
|
||||||
],
|
|
||||||
# AnimateDiff
|
|
||||||
"AnimateDiff_v2": [
|
|
||||||
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
|
||||||
],
|
|
||||||
"AnimateDiff_xl_beta": [
|
|
||||||
("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
|
||||||
],
|
|
||||||
# RIFE
|
|
||||||
"RIFE": [
|
|
||||||
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
|
|
||||||
],
|
|
||||||
# Qwen Prompt
|
|
||||||
"QwenPrompt": {
|
|
||||||
"file_list": [
|
|
||||||
("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/QwenPrompt/qwen2-1.5b-instruct",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
# Beautiful Prompt
|
|
||||||
"BeautifulPrompt": {
|
|
||||||
"file_list": [
|
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
# Omost prompt
|
|
||||||
"OmostPrompt": {
|
|
||||||
"file_list": [
|
|
||||||
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
# Translator
|
|
||||||
"opus-mt-zh-en": {
|
|
||||||
"file_list": [
|
|
||||||
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
|
||||||
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
|
||||||
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
|
||||||
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/translator/opus-mt-zh-en",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
# IP-Adapter
|
|
||||||
"IP-Adapter-SD": [
|
|
||||||
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
|
||||||
("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
|
||||||
],
|
|
||||||
"IP-Adapter-SDXL": [
|
|
||||||
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
|
||||||
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
|
||||||
],
|
|
||||||
# Kolors
|
|
||||||
"Kolors": {
|
|
||||||
"file_list": [
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
|
||||||
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
|
||||||
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/kolors/Kolors/text_encoder",
|
|
||||||
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
|
|
||||||
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"SDXL-vae-fp16-fix": [
|
|
||||||
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
|
||||||
],
|
|
||||||
# FLUX
|
|
||||||
"FLUX.1-dev": {
|
|
||||||
"file_list": [
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
|
||||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
|
||||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
|
||||||
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"FLUX.1-schnell": {
|
|
||||||
"file_list": [
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
|
||||||
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
|
||||||
("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
|
||||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
|
||||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
|
||||||
"models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
|
|
||||||
("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
|
|
||||||
],
|
|
||||||
"jasperai/Flux.1-dev-Controlnet-Depth": [
|
|
||||||
("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
|
|
||||||
],
|
|
||||||
"jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
|
|
||||||
("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
|
|
||||||
],
|
|
||||||
"jasperai/Flux.1-dev-Controlnet-Upscaler": [
|
|
||||||
("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
|
|
||||||
],
|
|
||||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
|
|
||||||
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
|
|
||||||
],
|
|
||||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
|
|
||||||
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
|
|
||||||
],
|
|
||||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
|
|
||||||
("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
|
|
||||||
],
|
|
||||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
|
|
||||||
("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
|
|
||||||
],
|
|
||||||
"InstantX/FLUX.1-dev-IP-Adapter": {
|
|
||||||
"file_list": [
|
|
||||||
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
|
|
||||||
("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
|
||||||
("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
|
||||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
# ESRGAN
|
|
||||||
"ESRGAN_x4": [
|
|
||||||
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
|
||||||
],
|
|
||||||
# RIFE
|
|
||||||
"RIFE": [
|
|
||||||
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
|
|
||||||
],
|
|
||||||
# Omnigen
|
|
||||||
"OmniGen-v1": {
|
|
||||||
"file_list": [
|
|
||||||
("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
|
|
||||||
("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
|
|
||||||
("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
|
|
||||||
("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
|
|
||||||
("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
|
|
||||||
("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
|
|
||||||
"models/OmniGen/OmniGen-v1/model.safetensors",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
# CogVideo
|
|
||||||
"CogVideoX-5B": {
|
|
||||||
"file_list": [
|
|
||||||
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
|
||||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
|
||||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
|
||||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
|
||||||
("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
|
||||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
|
||||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
|
||||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
|
||||||
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/CogVideo/CogVideoX-5b/text_encoder",
|
|
||||||
"models/CogVideo/CogVideoX-5b/transformer",
|
|
||||||
"models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
# Stable Diffusion 3.5
|
|
||||||
"StableDiffusion3.5-large": [
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
],
|
|
||||||
"StableDiffusion3.5-medium": [
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
],
|
|
||||||
"StableDiffusion3.5-large-turbo": [
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
|
||||||
],
|
|
||||||
"HunyuanVideo":{
|
|
||||||
"file_list": [
|
|
||||||
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
|
||||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
|
||||||
"models/HunyuanVideo/text_encoder_2",
|
|
||||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
|
||||||
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"HunyuanVideo-fp8":{
|
|
||||||
"file_list": [
|
|
||||||
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
|
||||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
|
||||||
("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
|
|
||||||
],
|
|
||||||
"load_path": [
|
|
||||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
|
||||||
"models/HunyuanVideo/text_encoder_2",
|
|
||||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
|
||||||
"models/HunyuanVideo/transformers/model.fp8.safetensors"
|
|
||||||
],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
Preset_model_id: TypeAlias = Literal[
|
|
||||||
"HunyuanDiT",
|
|
||||||
"stable-video-diffusion-img2vid-xt",
|
|
||||||
"ExVideo-SVD-128f-v1",
|
|
||||||
"ExVideo-CogVideoX-LoRA-129f-v1",
|
|
||||||
"StableDiffusion_v15",
|
|
||||||
"DreamShaper_8",
|
|
||||||
"AingDiffusion_v12",
|
|
||||||
"Flat2DAnimerge_v45Sharp",
|
|
||||||
"TextualInversion_VeryBadImageNegative_v1.3",
|
|
||||||
"StableDiffusionXL_v1",
|
|
||||||
"BluePencilXL_v200",
|
|
||||||
"StableDiffusionXL_Turbo",
|
|
||||||
"ControlNet_v11f1p_sd15_depth",
|
|
||||||
"ControlNet_v11p_sd15_softedge",
|
|
||||||
"ControlNet_v11f1e_sd15_tile",
|
|
||||||
"ControlNet_v11p_sd15_lineart",
|
|
||||||
"AnimateDiff_v2",
|
|
||||||
"AnimateDiff_xl_beta",
|
|
||||||
"RIFE",
|
|
||||||
"BeautifulPrompt",
|
|
||||||
"opus-mt-zh-en",
|
|
||||||
"IP-Adapter-SD",
|
|
||||||
"IP-Adapter-SDXL",
|
|
||||||
"StableDiffusion3",
|
|
||||||
"StableDiffusion3_without_T5",
|
|
||||||
"Kolors",
|
|
||||||
"SDXL-vae-fp16-fix",
|
|
||||||
"ControlNet_union_sdxl_promax",
|
|
||||||
"FLUX.1-dev",
|
|
||||||
"FLUX.1-schnell",
|
|
||||||
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
|
|
||||||
"jasperai/Flux.1-dev-Controlnet-Depth",
|
|
||||||
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
|
|
||||||
"jasperai/Flux.1-dev-Controlnet-Upscaler",
|
|
||||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
|
|
||||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
|
|
||||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
|
||||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
|
||||||
"InstantX/FLUX.1-dev-IP-Adapter",
|
|
||||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
|
||||||
"QwenPrompt",
|
|
||||||
"OmostPrompt",
|
|
||||||
"ESRGAN_x4",
|
|
||||||
"RIFE",
|
|
||||||
"OmniGen-v1",
|
|
||||||
"CogVideoX-5B",
|
|
||||||
"Annotators:Depth",
|
|
||||||
"Annotators:Softedge",
|
|
||||||
"Annotators:Lineart",
|
|
||||||
"Annotators:Normal",
|
|
||||||
"Annotators:Openpose",
|
|
||||||
"StableDiffusion3.5-large",
|
|
||||||
"StableDiffusion3.5-medium",
|
|
||||||
"HunyuanVideo",
|
|
||||||
"HunyuanVideo-fp8",
|
|
||||||
]
|
|
||||||
@@ -1,2 +1,2 @@
|
|||||||
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
|
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
|
||||||
from .processors import Annotator
|
from .processors import Annotator
|
||||||
|
|||||||
@@ -4,11 +4,10 @@ from .processors import Processor_id
|
|||||||
|
|
||||||
|
|
||||||
class ControlNetConfigUnit:
|
class ControlNetConfigUnit:
|
||||||
def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
|
def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
|
||||||
self.processor_id = processor_id
|
self.processor_id = processor_id
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.skip_processor = skip_processor
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetUnit:
|
class ControlNetUnit:
|
||||||
@@ -24,16 +23,6 @@ class MultiControlNetManager:
|
|||||||
self.models = [unit.model for unit in controlnet_units]
|
self.models = [unit.model for unit in controlnet_units]
|
||||||
self.scales = [unit.scale for unit in controlnet_units]
|
self.scales = [unit.scale for unit in controlnet_units]
|
||||||
|
|
||||||
def cpu(self):
|
|
||||||
for model in self.models:
|
|
||||||
model.cpu()
|
|
||||||
|
|
||||||
def to(self, device):
|
|
||||||
for model in self.models:
|
|
||||||
model.to(device)
|
|
||||||
for processor in self.processors:
|
|
||||||
processor.to(device)
|
|
||||||
|
|
||||||
def process_image(self, image, processor_id=None):
|
def process_image(self, image, processor_id=None):
|
||||||
if processor_id is None:
|
if processor_id is None:
|
||||||
processed_image = [processor(image) for processor in self.processors]
|
processed_image = [processor(image) for processor in self.processors]
|
||||||
@@ -48,14 +37,13 @@ class MultiControlNetManager:
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
sample, timestep, encoder_hidden_states, conditionings,
|
sample, timestep, encoder_hidden_states, conditionings,
|
||||||
tiled=False, tile_size=64, tile_stride=32, **kwargs
|
tiled=False, tile_size=64, tile_stride=32
|
||||||
):
|
):
|
||||||
res_stack = None
|
res_stack = None
|
||||||
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
|
for conditioning, model, scale in zip(conditionings, self.models, self.scales):
|
||||||
res_stack_ = model(
|
res_stack_ = model(
|
||||||
sample, timestep, encoder_hidden_states, conditioning, **kwargs,
|
sample, timestep, encoder_hidden_states, conditioning,
|
||||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||||
processor_id=processor.processor_id
|
|
||||||
)
|
)
|
||||||
res_stack_ = [res * scale for res in res_stack_]
|
res_stack_ = [res * scale for res in res_stack_]
|
||||||
if res_stack is None:
|
if res_stack is None:
|
||||||
@@ -63,29 +51,3 @@ class MultiControlNetManager:
|
|||||||
else:
|
else:
|
||||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
||||||
return res_stack
|
return res_stack
|
||||||
|
|
||||||
|
|
||||||
class FluxMultiControlNetManager(MultiControlNetManager):
|
|
||||||
def __init__(self, controlnet_units=[]):
|
|
||||||
super().__init__(controlnet_units=controlnet_units)
|
|
||||||
|
|
||||||
def process_image(self, image, processor_id=None):
|
|
||||||
if processor_id is None:
|
|
||||||
processed_image = [processor(image) for processor in self.processors]
|
|
||||||
else:
|
|
||||||
processed_image = [self.processors[processor_id](image)]
|
|
||||||
return processed_image
|
|
||||||
|
|
||||||
def __call__(self, conditionings, **kwargs):
|
|
||||||
res_stack, single_res_stack = None, None
|
|
||||||
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
|
|
||||||
res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
|
|
||||||
res_stack_ = [res * scale for res in res_stack_]
|
|
||||||
single_res_stack_ = [res * scale for res in single_res_stack_]
|
|
||||||
if res_stack is None:
|
|
||||||
res_stack = res_stack_
|
|
||||||
single_res_stack = single_res_stack_
|
|
||||||
else:
|
|
||||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
|
||||||
single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
|
|
||||||
return res_stack, single_res_stack
|
|
||||||
|
|||||||
@@ -3,47 +3,37 @@ import warnings
|
|||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
from controlnet_aux.processor import (
|
from controlnet_aux.processor import (
|
||||||
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector, NormalBaeDetector
|
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
Processor_id: TypeAlias = Literal[
|
Processor_id: TypeAlias = Literal[
|
||||||
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
|
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
|
||||||
]
|
]
|
||||||
|
|
||||||
class Annotator:
|
class Annotator:
|
||||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
|
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
|
||||||
if not skip_processor:
|
if processor_id == "canny":
|
||||||
if processor_id == "canny":
|
self.processor = CannyDetector()
|
||||||
self.processor = CannyDetector()
|
elif processor_id == "depth":
|
||||||
elif processor_id == "depth":
|
self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
|
||||||
self.processor = MidasDetector.from_pretrained(model_path).to(device)
|
elif processor_id == "softedge":
|
||||||
elif processor_id == "softedge":
|
self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
|
||||||
self.processor = HEDdetector.from_pretrained(model_path).to(device)
|
elif processor_id == "lineart":
|
||||||
elif processor_id == "lineart":
|
self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
|
||||||
self.processor = LineartDetector.from_pretrained(model_path).to(device)
|
elif processor_id == "lineart_anime":
|
||||||
elif processor_id == "lineart_anime":
|
self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
|
||||||
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
|
elif processor_id == "openpose":
|
||||||
elif processor_id == "openpose":
|
self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
|
||||||
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
|
elif processor_id == "tile":
|
||||||
elif processor_id == "normal":
|
|
||||||
self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
|
|
||||||
elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
|
|
||||||
self.processor = None
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported processor_id: {processor_id}")
|
|
||||||
else:
|
|
||||||
self.processor = None
|
self.processor = None
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported processor_id: {processor_id}")
|
||||||
|
|
||||||
self.processor_id = processor_id
|
self.processor_id = processor_id
|
||||||
self.detect_resolution = detect_resolution
|
self.detect_resolution = detect_resolution
|
||||||
|
|
||||||
def to(self,device):
|
|
||||||
if hasattr(self.processor,"model") and hasattr(self.processor.model,"to"):
|
|
||||||
|
|
||||||
self.processor.model.to(device)
|
def __call__(self, image):
|
||||||
|
|
||||||
def __call__(self, image, mask=None):
|
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
if self.processor_id == "openpose":
|
if self.processor_id == "openpose":
|
||||||
kwargs = {
|
kwargs = {
|
||||||
|
|||||||
@@ -1,41 +0,0 @@
|
|||||||
import torch, os, torchvision
|
|
||||||
from torchvision import transforms
|
|
||||||
import pandas as pd
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TextImageDataset(torch.utils.data.Dataset):
|
|
||||||
def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
|
|
||||||
self.steps_per_epoch = steps_per_epoch
|
|
||||||
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
|
|
||||||
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
|
|
||||||
self.text = metadata["text"].to_list()
|
|
||||||
self.height = height
|
|
||||||
self.width = width
|
|
||||||
self.image_processor = transforms.Compose(
|
|
||||||
[
|
|
||||||
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
|
|
||||||
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize([0.5], [0.5]),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
data_id = torch.randint(0, len(self.path), (1,))[0]
|
|
||||||
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
|
||||||
text = self.text[data_id]
|
|
||||||
image = Image.open(self.path[data_id]).convert("RGB")
|
|
||||||
target_height, target_width = self.height, self.width
|
|
||||||
width, height = image.size
|
|
||||||
scale = max(target_width / width, target_height / height)
|
|
||||||
shape = [round(height*scale),round(width*scale)]
|
|
||||||
image = torchvision.transforms.functional.resize(image,shape,interpolation=transforms.InterpolationMode.BILINEAR)
|
|
||||||
image = self.image_processor(image)
|
|
||||||
return {"text": text, "image": image}
|
|
||||||
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.steps_per_epoch
|
|
||||||
@@ -41,7 +41,7 @@ class RRDB(torch.nn.Module):
|
|||||||
|
|
||||||
class RRDBNet(torch.nn.Module):
|
class RRDBNet(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
|
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
|
||||||
super(RRDBNet, self).__init__()
|
super(RRDBNet, self).__init__()
|
||||||
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
||||||
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
|
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
|
||||||
@@ -65,21 +65,6 @@ class RRDBNet(torch.nn.Module):
|
|||||||
feat = self.lrelu(self.conv_up2(feat))
|
feat = self.lrelu(self.conv_up2(feat))
|
||||||
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return RRDBNetStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class RRDBNetStateDictConverter:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
return state_dict, {"upcast_to_float32": True}
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
return state_dict, {"upcast_to_float32": True}
|
|
||||||
|
|
||||||
|
|
||||||
class ESRGAN(torch.nn.Module):
|
class ESRGAN(torch.nn.Module):
|
||||||
@@ -88,8 +73,12 @@ class ESRGAN(torch.nn.Module):
|
|||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_manager(model_manager):
|
def from_pretrained(model_path):
|
||||||
return ESRGAN(model_manager.fetch_model("esrgan"))
|
model = RRDBNet()
|
||||||
|
state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
model.eval()
|
||||||
|
return ESRGAN(model)
|
||||||
|
|
||||||
def process_image(self, image):
|
def process_image(self, image):
|
||||||
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
|
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
|
||||||
@@ -107,12 +96,6 @@ class ESRGAN(torch.nn.Module):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
|
def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
|
||||||
if not isinstance(images, list):
|
|
||||||
images = [images]
|
|
||||||
is_single_image = True
|
|
||||||
else:
|
|
||||||
is_single_image = False
|
|
||||||
|
|
||||||
# Preprocess
|
# Preprocess
|
||||||
input_tensor = self.process_images(images)
|
input_tensor = self.process_images(images)
|
||||||
|
|
||||||
@@ -132,6 +115,4 @@ class ESRGAN(torch.nn.Module):
|
|||||||
|
|
||||||
# To images
|
# To images
|
||||||
output_images = self.decode_images(output_tensor)
|
output_images = self.decode_images(output_tensor)
|
||||||
if is_single_image:
|
|
||||||
output_images = output_images[0]
|
|
||||||
return output_images
|
return output_images
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ class IFBlock(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class IFNet(nn.Module):
|
class IFNet(nn.Module):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self):
|
||||||
super(IFNet, self).__init__()
|
super(IFNet, self).__init__()
|
||||||
self.block0 = IFBlock(7+4, c=90)
|
self.block0 = IFBlock(7+4, c=90)
|
||||||
self.block1 = IFBlock(7+4, c=90)
|
self.block1 = IFBlock(7+4, c=90)
|
||||||
@@ -99,8 +99,7 @@ class IFNet(nn.Module):
|
|||||||
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
|
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
|
||||||
return flow_list, mask_list[2], merged
|
return flow_list, mask_list[2], merged
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return IFNetStateDictConverter()
|
return IFNetStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
@@ -113,7 +112,7 @@ class IFNetStateDictConverter:
|
|||||||
return state_dict_
|
return state_dict_
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
def from_civitai(self, state_dict):
|
||||||
return self.from_diffusers(state_dict), {"upcast_to_float32": True}
|
return self.from_diffusers(state_dict)
|
||||||
|
|
||||||
|
|
||||||
class RIFEInterpolater:
|
class RIFEInterpolater:
|
||||||
@@ -125,7 +124,7 @@ class RIFEInterpolater:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_manager(model_manager):
|
def from_model_manager(model_manager):
|
||||||
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
|
return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)
|
||||||
|
|
||||||
def process_image(self, image):
|
def process_image(self, image):
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
@@ -203,7 +202,7 @@ class RIFESmoother(RIFEInterpolater):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_manager(model_manager):
|
def from_model_manager(model_manager):
|
||||||
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
|
return RIFESmoother(model_manager.RIFE, device=model_manager.device)
|
||||||
|
|
||||||
def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
|
def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
|
||||||
output_tensor = []
|
output_tensor = []
|
||||||
|
|||||||
@@ -1 +1,482 @@
|
|||||||
from .model_manager import *
|
import torch, os
|
||||||
|
from safetensors import safe_open
|
||||||
|
|
||||||
|
from .sd_text_encoder import SDTextEncoder
|
||||||
|
from .sd_unet import SDUNet
|
||||||
|
from .sd_vae_encoder import SDVAEEncoder
|
||||||
|
from .sd_vae_decoder import SDVAEDecoder
|
||||||
|
from .sd_lora import SDLoRA
|
||||||
|
|
||||||
|
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||||
|
from .sdxl_unet import SDXLUNet
|
||||||
|
from .sdxl_vae_decoder import SDXLVAEDecoder
|
||||||
|
from .sdxl_vae_encoder import SDXLVAEEncoder
|
||||||
|
|
||||||
|
from .sd_controlnet import SDControlNet
|
||||||
|
|
||||||
|
from .sd_motion import SDMotionModel
|
||||||
|
from .sdxl_motion import SDXLMotionModel
|
||||||
|
|
||||||
|
from .svd_image_encoder import SVDImageEncoder
|
||||||
|
from .svd_unet import SVDUNet
|
||||||
|
from .svd_vae_decoder import SVDVAEDecoder
|
||||||
|
from .svd_vae_encoder import SVDVAEEncoder
|
||||||
|
|
||||||
|
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
||||||
|
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||||
|
|
||||||
|
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||||
|
from .hunyuan_dit import HunyuanDiT
|
||||||
|
|
||||||
|
|
||||||
|
class ModelManager:
    """Detects the format of checkpoint files, builds the matching models and keeps them.

    Loaded models live in ``self.model`` (component name -> model, or a list of
    models for ``"controlnet"``) and the file each one was loaded from lives in
    ``self.model_path`` under the same key. Components are also readable as
    attributes, e.g. ``manager.unet``.
    """

    def __init__(self, torch_dtype=torch.float16, device="cuda"):
        self.torch_dtype = torch_dtype
        self.device = device
        self.model = {}
        self.model_path = {}
        # keyword -> (token names, embedding tensor); filled by load_textual_inversions.
        self.textual_inversion_dict = {}

    # ------------------------------------------------------------------
    # Format detection: each check looks for a key that identifies the format.
    # ------------------------------------------------------------------

    def is_stable_video_diffusion(self, state_dict):
        param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
        return param_name in state_dict

    def is_RIFE(self, state_dict):
        param_name = "block_tea.convblock3.0.1.weight"
        return param_name in state_dict or ("module." + param_name) in state_dict

    def is_beautiful_prompt(self, state_dict):
        param_name = "transformer.h.9.self_attention.query_key_value.weight"
        return param_name in state_dict

    def is_stabe_diffusion_xl(self, state_dict):
        # NOTE: the misspelled name is kept for backward compatibility.
        param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
        return param_name in state_dict

    # Correctly spelled alias of ``is_stabe_diffusion_xl``.
    is_stable_diffusion_xl = is_stabe_diffusion_xl

    def is_stable_diffusion(self, state_dict):
        # SDXL checkpoints are excluded explicitly before testing the SD key.
        if self.is_stabe_diffusion_xl(state_dict):
            return False
        param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
        return param_name in state_dict

    def is_controlnet(self, state_dict):
        param_name = "control_model.time_embed.0.weight"
        param_name_2 = "mid_block.resnets.1.time_emb_proj.weight"  # For controlnets in diffusers format
        return param_name in state_dict or param_name_2 in state_dict

    def is_animatediff(self, state_dict):
        param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
        return param_name in state_dict

    def is_animatediff_xl(self, state_dict):
        param_name = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight"
        return param_name in state_dict

    def is_sd_lora(self, state_dict):
        param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
        return param_name in state_dict

    def is_translator(self, state_dict):
        param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
        return param_name in state_dict and len(state_dict) == 254

    def is_ipadapter(self, state_dict):
        return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([3072, 1024])

    def is_ipadapter_image_encoder(self, state_dict):
        param_name = "vision_model.encoder.layers.31.self_attn.v_proj.weight"
        return param_name in state_dict and len(state_dict) == 521

    def is_ipadapter_xl(self, state_dict):
        return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([8192, 1280])

    def is_ipadapter_xl_image_encoder(self, state_dict):
        param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
        return param_name in state_dict and len(state_dict) == 777

    def is_hunyuan_dit_clip_text_encoder(self, state_dict):
        param_name = "bert.encoder.layer.23.attention.output.dense.weight"
        return param_name in state_dict

    def is_hunyuan_dit_t5_text_encoder(self, state_dict):
        param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
        return param_name in state_dict

    def is_hunyuan_dit(self, state_dict):
        param_name = "final_layer.adaLN_modulation.1.weight"
        return param_name in state_dict

    def is_diffusers_vae(self, state_dict):
        param_name = "quant_conv.weight"
        return param_name in state_dict

    def is_ExVideo_StableVideoDiffusion(self, state_dict):
        param_name = "blocks.185.positional_embedding.embeddings"
        return param_name in state_dict

    # ------------------------------------------------------------------
    # Loaders
    # ------------------------------------------------------------------

    def _load_single_model(self, component, model_class, state_dict, file_path, source="civitai"):
        """Build ``model_class()``, load converted weights and register it as ``component``.

        ``source`` selects the converter method: ``from_civitai`` or ``from_diffusers``.
        """
        model = model_class()
        convert = getattr(model.state_dict_converter(), f"from_{source}")
        model.load_state_dict(convert(state_dict))
        model.to(self.torch_dtype).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path

    def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
        """Load the Stable Video Diffusion components (all four when ``components`` is None)."""
        component_dict = {
            "image_encoder": SVDImageEncoder,
            "unet": SVDUNet,
            "vae_decoder": SVDVAEDecoder,
            "vae_encoder": SVDVAEEncoder,
        }
        if components is None:
            components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
        for component in components:
            if component == "unet":
                # Only the UNet takes add_positional_conv; it is loaded non-strictly.
                self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
            else:
                self.model[component] = component_dict[component]()
                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
            self.model[component].to(self.torch_dtype).to(self.device)
            self.model_path[component] = file_path

    def load_stable_diffusion(self, state_dict, components=None, file_path=""):
        """Load Stable Diffusion components; textual-inversion tokens extend the text encoder.

        Note: when the text encoder is loaded, the token-embedding entry of
        ``state_dict`` is replaced in place by the extended embedding matrix.
        """
        component_dict = {
            "text_encoder": SDTextEncoder,
            "unet": SDUNet,
            "vae_decoder": SDVAEDecoder,
            "vae_encoder": SDVAEEncoder,
            "refiner": SDXLUNet,
        }
        if components is None:
            components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
        embedding_key = "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"
        for component in components:
            if component == "text_encoder":
                # Add additional token embeddings to text encoder
                token_embeddings = [state_dict[embedding_key]]
                for keyword in self.textual_inversion_dict:
                    _, embeddings = self.textual_inversion_dict[keyword]
                    token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
                token_embeddings = torch.concat(token_embeddings, dim=0)
                state_dict[embedding_key] = token_embeddings
                self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
            else:
                self.model[component] = component_dict[component]()
            self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
            self.model[component].to(self.torch_dtype).to(self.device)
            self.model_path[component] = file_path

    def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
        """Load Stable Diffusion XL components (all five when ``components`` is None)."""
        component_dict = {
            "text_encoder": SDXLTextEncoder,
            "text_encoder_2": SDXLTextEncoder2,
            "unet": SDXLUNet,
            "vae_decoder": SDXLVAEDecoder,
            "vae_encoder": SDXLVAEEncoder,
        }
        if components is None:
            components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
        for component in components:
            self.model[component] = component_dict[component]()
            self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
            if component in ["vae_decoder", "vae_encoder"]:
                # These two model will output nan when float16 is enabled.
                # The precision problem happens in the last three resnet blocks.
                # I do not know how to solve this problem.
                self.model[component].to(torch.float32).to(self.device)
            else:
                self.model[component].to(self.torch_dtype).to(self.device)
            self.model_path[component] = file_path

    def load_controlnet(self, state_dict, file_path=""):
        """Append a ControlNet; several may be loaded, so the entry is a list."""
        component = "controlnet"
        if component not in self.model:
            self.model[component] = []
            self.model_path[component] = []
        model = SDControlNet()
        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
        model.to(self.torch_dtype).to(self.device)
        self.model[component].append(model)
        self.model_path[component].append(file_path)

    def load_animatediff(self, state_dict, file_path="", add_positional_conv=None):
        """Load AnimateDiff motion modules as ``"motion_modules"``."""
        component = "motion_modules"
        model = SDMotionModel(add_positional_conv=add_positional_conv)
        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv))
        model.to(self.torch_dtype).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path

    def load_animatediff_xl(self, state_dict, file_path=""):
        """Load AnimateDiff-XL motion modules as ``"motion_modules_xl"``."""
        self._load_single_model("motion_modules_xl", SDXLMotionModel, state_dict, file_path)

    def load_beautiful_prompt(self, state_dict, file_path=""):
        """Load the prompt-refining causal LM from the folder containing ``file_path``."""
        component = "beautiful_prompt"
        from transformers import AutoModelForCausalLM
        model_folder = os.path.dirname(file_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
        ).to(self.device).eval()
        self.model[component] = model
        self.model_path[component] = file_path

    def load_RIFE(self, state_dict, file_path=""):
        """Load the RIFE interpolation network; it is always kept in float32."""
        component = "RIFE"
        from ..extensions.RIFE import IFNet
        model = IFNet().eval()
        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
        model.to(torch.float32).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path

    def load_sd_lora(self, state_dict, alpha):
        """Apply a LoRA to the already loaded ``text_encoder`` and ``unet``."""
        SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
        SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)

    def load_translator(self, state_dict, file_path=""):
        """Load the translation model from the folder containing ``file_path``."""
        # This model is lightweight, we do not place it on GPU.
        component = "translator"
        from transformers import AutoModelForSeq2SeqLM
        model_folder = os.path.dirname(file_path)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
        self.model[component] = model
        self.model_path[component] = file_path

    def load_ipadapter(self, state_dict, file_path=""):
        self._load_single_model("ipadapter", SDIpAdapter, state_dict, file_path)

    def load_ipadapter_image_encoder(self, state_dict, file_path=""):
        self._load_single_model("ipadapter_image_encoder", IpAdapterCLIPImageEmbedder, state_dict, file_path, source="diffusers")

    def load_ipadapter_xl(self, state_dict, file_path=""):
        self._load_single_model("ipadapter_xl", SDXLIpAdapter, state_dict, file_path)

    def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
        self._load_single_model("ipadapter_xl_image_encoder", IpAdapterXLCLIPImageEmbedder, state_dict, file_path, source="diffusers")

    def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
        self._load_single_model("hunyuan_dit_clip_text_encoder", HunyuanDiTCLIPTextEncoder, state_dict, file_path)

    def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
        self._load_single_model("hunyuan_dit_t5_text_encoder", HunyuanDiTT5TextEncoder, state_dict, file_path)

    def load_hunyuan_dit(self, state_dict, file_path=""):
        self._load_single_model("hunyuan_dit", HunyuanDiT, state_dict, file_path)

    def load_diffusers_vae(self, state_dict, file_path=""):
        """Load a diffusers-format VAE as both ``vae_encoder`` and ``vae_decoder``."""
        # TODO: detect SD and SDXL
        # NOTE(review): unlike load_stable_diffusion_xl, these SDXL VAE classes are
        # cast to self.torch_dtype rather than float32 - confirm this is intended.
        self._load_single_model("vae_encoder", SDXLVAEEncoder, state_dict, file_path, source="diffusers")
        self._load_single_model("vae_decoder", SDXLVAEDecoder, state_dict, file_path, source="diffusers")

    def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
        """Rebuild the loaded SVD UNet with positional convs and overlay ExVideo weights.

        Requires a ``"unet"`` to be loaded already (``KeyError`` otherwise).
        ``file_path`` is accepted for signature parity but is not recorded.
        """
        unet_state_dict = self.model["unet"].state_dict()
        self.model["unet"].to("cpu")
        del self.model["unet"]
        # The first dimension of the positional embedding sizes the new UNet.
        add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
        self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
        self.model["unet"].load_state_dict(unet_state_dict, strict=False)
        self.model["unet"].load_state_dict(state_dict, strict=False)
        self.model["unet"].to(self.torch_dtype).to(self.device)

    def search_for_embeddings(self, state_dict):
        """Collect every tensor found in a (possibly nested) dict, depth first."""
        embeddings = []
        for k in state_dict:
            if isinstance(state_dict[k], torch.Tensor):
                embeddings.append(state_dict[k])
            elif isinstance(state_dict[k], dict):
                embeddings += self.search_for_embeddings(state_dict[k])
        return embeddings

    def load_textual_inversions(self, folder):
        """Read textual-inversion files from ``folder`` into ``textual_inversion_dict``.

        ``.txt`` files are skipped. For each other file, the first 2-D tensor whose
        second dimension is 768 is taken; the keyword is the file name without
        extension and tokens are named ``{keyword}_{i}``.
        """
        # Store additional tokens here
        self.textual_inversion_dict = {}

        # Load every textual inversion file
        for file_name in os.listdir(folder):
            if file_name.endswith(".txt"):
                continue
            keyword = os.path.splitext(file_name)[0]
            state_dict = load_state_dict(os.path.join(folder, file_name))

            # Search for embeddings
            for embeddings in self.search_for_embeddings(state_dict):
                if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
                    tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
                    self.textual_inversion_dict[keyword] = (tokens, embeddings)
                    break

    def load_model(self, file_path, components=None, lora_alphas=None):
        """Load one checkpoint file, dispatching on the first matching format.

        Args:
            file_path: path of the checkpoint to load.
            components: optional subset of components (SD / SDXL checkpoints only).
            lora_alphas: list of LoRA strengths; when the file is an SD LoRA the
                first entry is removed from this list and used.

        Raises:
            IndexError: the file is an SD LoRA but no alpha is left.

        A file whose format is not recognized is ignored without an error.
        """
        state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
        if self.is_stable_video_diffusion(state_dict):
            self.load_stable_video_diffusion(state_dict, file_path=file_path)
        elif self.is_animatediff(state_dict):
            self.load_animatediff(state_dict, file_path=file_path)
        elif self.is_animatediff_xl(state_dict):
            self.load_animatediff_xl(state_dict, file_path=file_path)
        elif self.is_controlnet(state_dict):
            self.load_controlnet(state_dict, file_path=file_path)
        elif self.is_stabe_diffusion_xl(state_dict):
            self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
        elif self.is_stable_diffusion(state_dict):
            self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
        elif self.is_sd_lora(state_dict):
            if not lora_alphas:
                raise IndexError(f"No LoRA alpha left in lora_alphas for LoRA file {file_path}.")
            self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
        elif self.is_beautiful_prompt(state_dict):
            self.load_beautiful_prompt(state_dict, file_path=file_path)
        elif self.is_RIFE(state_dict):
            self.load_RIFE(state_dict, file_path=file_path)
        elif self.is_translator(state_dict):
            self.load_translator(state_dict, file_path=file_path)
        elif self.is_ipadapter(state_dict):
            self.load_ipadapter(state_dict, file_path=file_path)
        elif self.is_ipadapter_image_encoder(state_dict):
            self.load_ipadapter_image_encoder(state_dict, file_path=file_path)
        elif self.is_ipadapter_xl(state_dict):
            self.load_ipadapter_xl(state_dict, file_path=file_path)
        elif self.is_ipadapter_xl_image_encoder(state_dict):
            self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
        elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
            self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
        elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
            self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
        elif self.is_hunyuan_dit(state_dict):
            self.load_hunyuan_dit(state_dict, file_path=file_path)
        elif self.is_diffusers_vae(state_dict):
            self.load_diffusers_vae(state_dict, file_path=file_path)
        elif self.is_ExVideo_StableVideoDiffusion(state_dict):
            self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)

    def load_models(self, file_path_list, lora_alphas=None):
        """Load several files in order; LoRA files take alphas from ``lora_alphas`` in order."""
        # Work on a copy so the caller's list is not emptied as alphas are used up.
        remaining_alphas = list(lora_alphas) if lora_alphas is not None else []
        for file_path in file_path_list:
            self.load_model(file_path, lora_alphas=remaining_alphas)

    def to(self, device):
        """Move every loaded model to ``device``, then release cached CUDA memory."""
        for component in self.model:
            if isinstance(self.model[component], list):
                for model in self.model[component]:
                    model.to(device)
            else:
                self.model[component].to(device)
        torch.cuda.empty_cache()

    def get_model_with_model_path(self, model_path):
        """Return the loaded model whose recorded file is the same file as ``model_path``.

        Raises:
            ValueError: no loaded model was read from that file.
        """
        for component in self.model_path:
            if isinstance(self.model_path[component], str):
                if os.path.samefile(self.model_path[component], model_path):
                    return self.model[component]
            elif isinstance(self.model_path[component], list):
                for i, model_path_ in enumerate(self.model_path[component]):
                    if os.path.samefile(model_path_, model_path):
                        return self.model[component][i]
        raise ValueError(f"Please load model {model_path} before you use it.")

    def __getattr__(self, name):
        """Expose loaded components as attributes (``manager.unet``).

        Only called when normal lookup fails. Raises ``AttributeError`` for
        unknown names so ``hasattr`` and ``getattr(..., default)`` work.
        """
        # Read `model` through __dict__: going through attribute access would
        # recurse into __getattr__ while `model` is not set yet.
        models = self.__dict__.get("model")
        if models is not None and name in models:
            return models[name]
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict(file_path, torch_dtype=None):
    """Load a state dict from ``file_path``, choosing the reader by file extension.

    Paths ending in ``.safetensors`` use the safetensors reader; everything
    else is handed to the ``torch.load``-based reader. ``torch_dtype`` is
    forwarded unchanged.
    """
    is_safetensors = file_path.endswith(".safetensors")
    reader = load_state_dict_from_safetensors if is_safetensors else load_state_dict_from_bin
    return reader(file_path, torch_dtype=torch_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
    """Read every tensor of a safetensors file onto the CPU.

    When ``torch_dtype`` is given, each tensor is converted to it as it is read.
    Returns a plain ``dict`` mapping key -> tensor.
    """
    tensors = {}
    with safe_open(file_path, framework="pt", device="cpu") as handle:
        for key in handle.keys():
            tensor = handle.get_tensor(key)
            tensors[key] = tensor if torch_dtype is None else tensor.to(torch_dtype)
    return tensors
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict_from_bin(file_path, torch_dtype=None):
    """Load a ``torch.save``-style checkpoint onto the CPU.

    When ``torch_dtype`` is given, top-level tensor values are converted to it;
    non-tensor values (including nested dicts) are left untouched.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from sources you trust.
    loaded = torch.load(file_path, map_location="cpu")
    if torch_dtype is None:
        return loaded
    for key, value in loaded.items():
        if isinstance(value, torch.Tensor):
            loaded[key] = value.to(torch_dtype)
    return loaded
|
||||||
|
|
||||||
|
|
||||||
|
def search_parameter(param, state_dict):
    """Return the first key in ``state_dict`` whose tensor matches ``param``.

    A candidate matches when it has the same number of elements and its
    distance to ``param`` (``torch.dist``) is below ``1e-6``. Candidates with a
    different shape but equal element count are compared after flattening.
    Returns ``None`` when nothing matches.
    """
    for key, candidate in state_dict.items():
        if param.numel() != candidate.numel():
            continue
        if param.shape == candidate.shape:
            distance = torch.dist(param, candidate)
        else:
            distance = torch.dist(param.flatten(), candidate.flatten())
        if distance < 1e-6:
            return key
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
    """Print a source-key -> target-key mapping found by comparing tensor values.

    Each source tensor is looked up in ``target_state_dict`` with
    ``search_parameter``. With ``split_qkv`` enabled, an unmatched tensor whose
    first dimension is divisible by three is cut into three equal slices along
    that dimension and each slice is looked up separately; the mapping is
    printed only if all three are found. Target keys that were never matched
    are reported at the end. Nothing is returned.
    """
    found_targets = set()
    with torch.no_grad():
        for source_key, tensor in source_state_dict.items():
            target_key = search_parameter(tensor, target_state_dict)
            if target_key is not None:
                print(f'"{source_key}": "{target_key}",')
                found_targets.add(target_key)
                continue
            if not (split_qkv and len(tensor.shape)>=1 and tensor.shape[0]%3==0):
                continue
            # Possibly a fused q/k/v tensor: try each third on its own.
            third = tensor.shape[0] // 3
            parts = [
                search_parameter(tensor[i*third: i*third+third], target_state_dict)
                for i in range(3)
            ]
            if None not in parts:
                print(f'"{source_key}": {parts},')
                found_targets.update(parts)
    for target_key in target_state_dict:
        if target_key not in found_targets:
            print("Cannot find", target_key, target_state_dict[target_key].shape)
|
||||||
|
|||||||
@@ -1,408 +0,0 @@
|
|||||||
import torch
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from .sd3_dit import TimestepEmbeddings
|
|
||||||
from .attention import Attention
|
|
||||||
from .utils import load_state_dict_from_folder
|
|
||||||
from .tiler import TileWorker2Dto3D
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogPatchify(torch.nn.Module):
    """Project a 5-D input with a Conv3d and flatten it into a token sequence.

    The convolution uses kernel and stride ``(1, patch_size, patch_size)``;
    its output is rearranged from ``B C T H W`` to ``B (T H W) C``.
    """

    def __init__(self, dim_in, dim_out, patch_size) -> None:
        super().__init__()
        patch = (1, patch_size, patch_size)
        self.proj = torch.nn.Conv3d(dim_in, dim_out, kernel_size=patch, stride=patch)

    def forward(self, hidden_states):
        projected = self.proj(hidden_states)
        return rearrange(projected, "B C T H W -> B (T H W) C")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogAdaLayerNorm(torch.nn.Module):
    """LayerNorm whose shift/scale (and gates) are predicted from a conditioning vector.

    ``emb`` goes through SiLU and a linear layer, gains a singleton axis at
    position 1 and is split along the last axis:

    * ``single=True``: two chunks (shift, scale); only ``hidden_states`` is
      modulated and returned.
    * ``single=False``: six chunks; ``hidden_states`` and ``prompt_emb`` are
      modulated with their own shift/scale, and the two gate chunks are
      returned unchanged alongside them.
    """

    def __init__(self, dim, dim_cond, single=False):
        super().__init__()
        self.single = single
        num_chunks = 2 if single else 6
        self.linear = torch.nn.Linear(dim_cond, dim * num_chunks)
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5)

    def forward(self, hidden_states, prompt_emb, emb):
        modulation = self.linear(torch.nn.functional.silu(emb)).unsqueeze(1)
        if self.single:
            shift, scale = modulation.chunk(2, dim=2)
            return self.norm(hidden_states) * (1 + scale) + shift

        shift_a, scale_a, gate_a, shift_b, scale_b, gate_b = modulation.chunk(6, dim=2)
        hidden_states = self.norm(hidden_states) * (1 + scale_a) + shift_a
        prompt_emb = self.norm(prompt_emb) * (1 + scale_b) + shift_b
        return hidden_states, prompt_emb, gate_a, gate_b
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogDiTBlock(torch.nn.Module):
    """One transformer block that jointly processes text and video tokens.

    Text tokens (``prompt_emb``) are concatenated in front of the video tokens
    (``hidden_states``) for both the attention and the feed-forward pass; each
    stream has its own adaptive-norm shift/scale/gate.

    Args:
        dim: Token width.
        dim_cond: Width of the conditioning (time) embedding.
        num_heads: Number of attention heads; the per-head width is
            ``dim // num_heads``.
    """

    def __init__(self, dim, dim_cond, num_heads):
        super().__init__()
        head_dim = dim // num_heads
        self.norm1 = CogAdaLayerNorm(dim, dim_cond)
        # The head count used to be hard-coded to 48 here, silently ignoring
        # the constructor argument that head_dim is derived from.
        self.attn1 = Attention(q_dim=dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
        self.norm_q = torch.nn.LayerNorm((head_dim,), eps=1e-06, elementwise_affine=True)
        self.norm_k = torch.nn.LayerNorm((head_dim,), eps=1e-06, elementwise_affine=True)

        self.norm2 = CogAdaLayerNorm(dim, dim_cond)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

    def apply_rotary_emb(self, x, freqs_cis):
        """Rotate consecutive channel pairs of ``x`` by the given angles.

        ``freqs_cis`` is a ``(cos, sin)`` pair of 2-D tensors; two leading
        singleton axes are added so they broadcast against a 4-D ``x``.  The
        arithmetic runs in float32 and the result is cast back to ``x.dtype``.
        """
        cos, sin = freqs_cis
        cos = cos[None, None].to(x.device)
        sin = sin[None, None].to(x.device)
        # View the last axis as (pairs, 2) and build (-imag, real) per pair,
        # i.e. every pair rotated by 90 degrees.
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

    def process_qkv(self, q, k, v, image_rotary_emb, text_seq_length):
        """Normalise q/k per head and apply rotary embedding to non-text tokens.

        The first ``text_seq_length`` positions along dim 2 are left without
        rotary embedding; the remaining positions are overwritten in place.
        NOTE(review): this presumes ``Attention`` hands over q/k with the
        sequence on dim 2 and head_dim last -- confirm against ``Attention``.
        """
        q = self.norm_q(q)
        k = self.norm_k(k)
        q[:, :, text_seq_length:] = self.apply_rotary_emb(q[:, :, text_seq_length:], image_rotary_emb)
        k[:, :, text_seq_length:] = self.apply_rotary_emb(k[:, :, text_seq_length:], image_rotary_emb)
        return q, k, v

    def forward(self, hidden_states, prompt_emb, time_emb, image_rotary_emb):
        """Run gated attention then gated feed-forward on both token streams.

        Returns:
            The updated ``(hidden_states, prompt_emb)`` pair.
        """
        text_len = prompt_emb.shape[1]

        # Attention over the concatenated [text, video] sequence.
        norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm1(
            hidden_states, prompt_emb, time_emb
        )
        attention_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        attention_io = self.attn1(
            attention_io,
            qkv_preprocessor=lambda q, k, v: self.process_qkv(q, k, v, image_rotary_emb, text_len)
        )

        hidden_states = hidden_states + gate_a * attention_io[:, text_len:]
        prompt_emb = prompt_emb + gate_b * attention_io[:, :text_len]

        # Feed forward over the same concatenated layout.
        norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm2(
            hidden_states, prompt_emb, time_emb
        )
        ff_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_io = self.ff(ff_io)

        hidden_states = hidden_states + gate_a * ff_io[:, text_len:]
        prompt_emb = prompt_emb + gate_b * ff_io[:, :text_len]

        return hidden_states, prompt_emb
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogDiT(torch.nn.Module):
    """Diffusion transformer over patchified video latents and text tokens.

    All sizes are fixed in ``__init__``: 16 latent channels, 2x2 spatial
    patches, token width 3072, a 512-wide time embedding, 4096-wide text
    features and 42 blocks with 48 heads each.
    """

    def __init__(self):
        super().__init__()
        self.patchify = CogPatchify(16, 3072, 2)
        self.time_embedder = TimestepEmbeddings(3072, 512)
        self.context_embedder = torch.nn.Linear(4096, 3072)
        self.blocks = torch.nn.ModuleList([CogDiTBlock(3072, 512, 48) for _ in range(42)])
        self.norm_final = torch.nn.LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
        self.norm_out = CogAdaLayerNorm(3072, 512, single=True)
        self.proj_out = torch.nn.Linear(3072, 64, bias=True)

    def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
        """Fit a ``(h, w)`` grid into the target box and centre it.

        The grid is scaled, keeping its aspect ratio, until one side matches
        the target; the result is centred on the other axis.

        Returns:
            ``((top, left), (bottom, right))`` of the fitted region.
        """
        tw = tgt_width
        th = tgt_height
        h, w = src
        if h / w > th / tw:
            # Relatively taller than the target: height is the binding side.
            resize_height = th
            resize_width = int(round(th / h * w))
        else:
            resize_width = tw
            resize_height = int(round(tw / w * h))

        crop_top = int(round((th - resize_height) / 2.0))
        crop_left = int(round((tw - resize_width) / 2.0))

        return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)

    def get_3d_rotary_pos_embed(
        self, embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
    ):
        """Build rotary angles for every (t, h, w) position.

        ``embed_dim`` is split into ``embed_dim // 4`` channels for time and
        ``embed_dim // 8 * 3`` channels each for height and width.  Spatial
        positions are sampled between the two corners in ``crops_coords``
        (end-exclusive); temporal positions are ``0 .. temporal_size - 1``.

        Returns:
            ``(cos, sin)`` tensors of shape ``(T*H*W, dim_t + dim_h + dim_w)``
            when ``use_real`` is true, otherwise one complex tensor of that
            shape built with ``torch.polar``.
        """
        start, stop = crops_coords
        grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
        grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)

        # Channel budget for each axis.
        dim_t = embed_dim // 4
        dim_h = embed_dim // 8 * 3
        dim_w = embed_dim // 8 * 3

        def axis_angles(positions, dim):
            # Outer product position x inverse frequency; each angle is then
            # duplicated so both members of a rotated channel pair share it.
            inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
            angles = torch.einsum("n , f -> n f", torch.from_numpy(positions).float(), inv_freq)
            return angles.repeat_interleave(2, dim=-1)

        freqs_t = axis_angles(grid_t, dim_t)
        freqs_h = axis_angles(grid_h, dim_h)
        freqs_w = axis_angles(grid_w, dim_w)

        # Broadcast every axis to the full (T, H, W) grid, then concatenate
        # the three channel groups.
        t, h, w = freqs_t.shape[0], freqs_h.shape[0], freqs_w.shape[0]
        freqs = torch.cat(
            [
                freqs_t[:, None, None, :].expand(t, h, w, -1),
                freqs_h[None, :, None, :].expand(t, h, w, -1),
                freqs_w[None, None, :, :].expand(t, h, w, -1),
            ],
            dim=-1,
        )
        freqs = freqs.reshape(t * h * w, freqs.shape[-1])

        if use_real:
            return freqs.cos(), freqs.sin()
        return torch.polar(torch.ones_like(freqs), freqs)

    def prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        device: torch.device,
    ):
        """Return ``(cos, sin)`` rotary tables for a latent of the given size.

        The token grid is ``height // 2`` by ``width // 2`` (2x2 patches) and is
        fitted into a reference grid of ``480 // 16`` by ``720 // 16`` cells.
        Tables are computed with a per-head width of 64 and moved to ``device``.
        """
        grid_height = height // 2
        grid_width = width // 2
        base_size_width = 720 // (8 * 2)
        base_size_height = 480 // (8 * 2)

        grid_crops_coords = self.get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = self.get_3d_rotary_pos_embed(
            embed_dim=64,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
            use_real=True,
        )
        return freqs_cos.to(device=device), freqs_sin.to(device=device)

    def unpatchify(self, hidden_states, height, width):
        """Fold ``(B, T*H*W, C*2*2)`` tokens back into ``(B, C, T, height, width)``."""
        return rearrange(
            hidden_states, "B (T H W) (C P Q) -> B C T (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2
        )

    def build_mask(self, T, H, W, dtype, device, is_bound):
        """Build a ``(1, 1, T, H, W)`` blending weight for one tile.

        Each weight is the distance to the nearest tile face, clipped to
        ``[1, border_width]`` and divided by ``border_width``.  ``is_bound``
        holds six flags (t-start, t-end, h-start, h-end, w-start, w-end);
        a flagged face does not attenuate the weight.
        """
        t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
        h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
        w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
        border_width = (H + W) // 4
        pad = torch.ones_like(h) * border_width
        mask = torch.stack([
            pad if is_bound[0] else t + 1,
            pad if is_bound[1] else T - t,
            pad if is_bound[2] else h + 1,
            pad if is_bound[3] else H - h,
            pad if is_bound[4] else w + 1,
            pad if is_bound[5] else W - w
        ]).min(dim=0).values
        mask = mask.clip(1, border_width)
        mask = (mask / border_width).to(dtype=dtype, device=device)
        return rearrange(mask, "T H W -> 1 1 T H W")

    @staticmethod
    def _as_pair(value):
        """Return ``value`` as an ``(h, w)`` pair; a scalar is used for both axes."""
        if isinstance(value, (tuple, list)):
            size_h, size_w = value
            return size_h, size_w
        return value, value

    @staticmethod
    def _plan_tiles(H, W, tile_size, tile_stride):
        """List ``(h_start, h_end, w_start, w_end)`` tiles covering an H x W plane.

        Tiles start every ``tile_stride`` cells.  A start is skipped when the
        previous tile on that axis already reached the far edge, and a tile
        that would overshoot is shifted back so that it ends on the edge.
        """
        size_h, size_w = CogDiT._as_pair(tile_size)
        stride_h, stride_w = CogDiT._as_pair(tile_stride)
        tasks = []
        for h in range(0, H, stride_h):
            if h - stride_h >= 0 and h - stride_h + size_h >= H:
                continue
            for w in range(0, W, stride_w):
                if w - stride_w >= 0 and w - stride_w + size_w >= W:
                    continue
                top, bottom = h, h + size_h
                left, right = w, w + size_w
                if bottom > H:
                    top, bottom = max(H - size_h, 0), H
                if right > W:
                    left, right = max(W - size_w, 0), W
                tasks.append((top, bottom, left, right))
        return tasks

    def tiled_forward(self, hidden_states, timestep, prompt_emb, tile_size=(60, 90), tile_stride=(30, 45)):
        """Run ``forward`` on overlapping spatial tiles and blend the results.

        ``tile_size`` and ``tile_stride`` may each be an int (same value on
        both axes) or an ``(h, w)`` pair.  Tile outputs are accumulated with
        the weights from ``build_mask`` and normalised by the summed weights.
        """
        B, C, T, H, W = hidden_states.shape
        value = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
        weight = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)

        for hl, hr, wl, wr in self._plan_tiles(H, W, tile_size, tile_stride):
            mask = self.build_mask(
                T, (hr-hl), (wr-wl),
                hidden_states.dtype, hidden_states.device,
                # Time is never tiled; spatial faces on the frame border keep full weight.
                is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W)
            )
            model_output = self.forward(hidden_states[:, :, :, hl:hr, wl:wr], timestep, prompt_emb)
            value[:, :, :, hl:hr, wl:wr] += model_output * mask
            weight[:, :, :, hl:hr, wl:wr] += mask

        return value / weight

    def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30, use_gradient_checkpointing=False):
        """Run the transformer on a ``(B, C, T, H, W)`` latent.

        With ``tiled=True`` the call is delegated to ``TileWorker2Dto3D`` and
        each tile is processed by a plain (untiled) ``forward``; rotary tables
        are rebuilt per tile.  Otherwise the latent is patchified, passed
        through all blocks together with the embedded prompt, normalised,
        projected and unpatchified back to the input's spatial size.
        Gradient checkpointing is applied per block only in training mode.
        """
        if tiled:
            return TileWorker2Dto3D().tiled_forward(
                forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
                model_input=hidden_states,
                # The other callers of TileWorker2Dto3D in this code base pass
                # (h, w) pairs, so scalars are expanded to that form.
                tile_size=self._as_pair(tile_size), tile_stride=self._as_pair(tile_stride),
                tile_device=hidden_states.device, tile_dtype=hidden_states.dtype,
                computation_device=self.context_embedder.weight.device, computation_dtype=self.context_embedder.weight.dtype
            )
        num_frames, height, width = hidden_states.shape[-3:]
        if image_rotary_emb is None:
            image_rotary_emb = self.prepare_rotary_positional_embeddings(height, width, num_frames, device=self.context_embedder.weight.device)
        hidden_states = self.patchify(hidden_states)
        time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
        prompt_emb = self.context_embedder(prompt_emb)

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states, prompt_emb, time_emb, image_rotary_emb,
                    use_reentrant=False,
                )
            else:
                hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)

        # norm_final sees the concatenated [text, video] sequence; the text
        # part is dropped again right after.
        hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
        hidden_states = self.norm_final(hidden_states)
        hidden_states = hidden_states[:, prompt_emb.shape[1]:]
        hidden_states = self.norm_out(hidden_states, prompt_emb, time_emb)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = self.unpatchify(hidden_states, height, width)

        return hidden_states

    @staticmethod
    def state_dict_converter():
        """Return the converter that maps external checkpoints to this layout."""
        return CogDiTStateDictConverter()

    @staticmethod
    def from_pretrained(file_path, torch_dtype=torch.bfloat16):
        """Build a model in ``torch_dtype`` and load converted weights from ``file_path``."""
        model = CogDiT().to(torch_dtype)
        state_dict = load_state_dict_from_folder(file_path, torch_dtype=torch_dtype)
        state_dict = CogDiT.state_dict_converter().from_diffusers(state_dict)
        model.load_state_dict(state_dict)
        return model
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogDiTStateDictConverter:
    """Rename checkpoint keys from the diffusers layout to the ``CogDiT`` layout."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Return a new dict with keys renamed for ``CogDiT``.

        Top-level keys are looked up in a fixed table.  Keys of the form
        ``transformer_blocks.<i>.<suffix>`` become ``blocks.<i>.<mapped suffix>``;
        an unknown suffix raises ``KeyError``.  Any other key is dropped.
        ``patch_embed.proj.weight`` gains a singleton axis at dim 2 so that it
        fits the ``Conv3d`` used for patchifying.
        """
        kinds = ("weight", "bias")

        def expand(module_pairs):
            # (source module, target module) -> one entry per parameter kind.
            return {f"{src}.{kind}": f"{dst}.{kind}" for src, dst in module_pairs for kind in kinds}

        rename_dict = expand([
            ("patch_embed.proj", "patchify.proj"),
            ("patch_embed.text_proj", "context_embedder"),
            ("time_embedding.linear_1", "time_embedder.timestep_embedder.0"),
            ("time_embedding.linear_2", "time_embedder.timestep_embedder.2"),
            ("norm_final", "norm_final"),
            ("norm_out.linear", "norm_out.linear"),
            ("norm_out.norm", "norm_out.norm"),
            ("proj_out", "proj_out"),
        ])
        suffix_dict = expand([
            ("norm1.linear", "norm1.linear"),
            ("norm1.norm", "norm1.norm"),
            ("attn1.norm_q", "norm_q"),
            ("attn1.norm_k", "norm_k"),
            ("attn1.to_q", "attn1.to_q"),
            ("attn1.to_k", "attn1.to_k"),
            ("attn1.to_v", "attn1.to_v"),
            ("attn1.to_out.0", "attn1.to_out"),
            ("norm2.linear", "norm2.linear"),
            ("norm2.norm", "norm2.norm"),
            ("ff.net.0.proj", "ff.0"),
            ("ff.net.2", "ff.2"),
        ])

        converted = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                if name == "patch_embed.proj.weight":
                    param = param.unsqueeze(2)
                converted[rename_dict[name]] = param
                continue
            names = name.split(".")
            if names[0] != "transformer_blocks":
                continue
            suffix = ".".join(names[2:])
            converted[f"blocks.{names[1]}." + suffix_dict[suffix]] = param
        return converted

    def from_civitai(self, state_dict):
        """Same conversion as ``from_diffusers``."""
        return self.from_diffusers(state_dict)
|
|
||||||
@@ -1,518 +0,0 @@
|
|||||||
import torch
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from .tiler import TileWorker2Dto3D
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Downsample3D(torch.nn.Module):
    """Downsample a ``(B, C, T, H, W)`` tensor with a per-frame 2-D convolution.

    When ``compress_time`` is set, the time axis is first halved by average
    pooling; for an odd frame count the first frame is kept as is and only the
    remaining frames are pooled.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 2,
        padding: int = 0,
        compress_time: bool = False,
    ):
        super().__init__()

        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.compress_time = compress_time

    def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
        """Return the downsampled tensor; ``xq`` is accepted but not used."""
        pool = torch.nn.functional.avg_pool1d

        if self.compress_time:
            batch_size, channels, frames, height, width = x.shape

            # Move time last and fold space into the batch: (B*H*W, C, T).
            seq = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)

            if frames % 2 == 1:
                head, tail = seq[..., :1], seq[..., 1:]
                if tail.shape[-1] > 0:
                    tail = pool(tail, kernel_size=2, stride=2)
                seq = torch.cat([head, tail], dim=-1)
            else:
                seq = pool(seq, kernel_size=2, stride=2)

            # Back to (B, C, T', H, W).
            x = seq.reshape(batch_size, height, width, channels, seq.shape[-1]).permute(0, 3, 4, 1, 2)

        # One extra zero column on the right and one zero row at the bottom.
        x = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)

        # Fold time into the batch so the 2-D conv sees single frames.
        batch_size, channels, frames, height, width = x.shape
        flat = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
        flat = self.conv(flat)
        return flat.reshape(batch_size, frames, *flat.shape[1:]).permute(0, 2, 1, 3, 4)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample3D(torch.nn.Module):
    """Upsample a ``(B, C, T, H, W)`` tensor by 2 and refine it with a 2-D conv.

    Without ``compress_time`` only height and width are doubled.  With it, the
    time axis is doubled as well, except that a single frame -- or the first
    frame of an odd-length clip -- is only enlarged spatially.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        compress_time: bool = False,
    ) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.compress_time = compress_time

    def forward(self, inputs: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
        """Return the upsampled tensor; ``xq`` is accepted but not used."""
        upscale = torch.nn.functional.interpolate
        num_frames = inputs.shape[2]

        if not self.compress_time:
            # Spatial-only: fold time into the batch and interpolate in 2-D.
            b, c, t, h, w = inputs.shape
            frames_2d = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
            frames_2d = upscale(frames_2d, scale_factor=2.0)
            inputs = frames_2d.reshape(b, t, c, *frames_2d.shape[2:]).permute(0, 2, 1, 3, 4)
        elif num_frames <= 1:
            # Lone frame: 2-D interpolation, then restore the time axis.
            inputs = upscale(inputs.squeeze(2), scale_factor=2.0)[:, :, None, :, :]
        elif num_frames % 2 == 1:
            # First frame is upsampled in 2-D, the rest in 3-D.
            first = upscale(inputs[:, :, 0], scale_factor=2.0)[:, :, None, :, :]
            rest = upscale(inputs[:, :, 1:], scale_factor=2.0)
            inputs = torch.cat([first, rest], dim=2)
        else:
            inputs = upscale(inputs, scale_factor=2.0)

        # Per-frame convolution.
        b, c, t, h, w = inputs.shape
        flat = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        flat = self.conv(flat)
        return flat.reshape(b, t, *flat.shape[1:]).permute(0, 2, 1, 3, 4)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogVideoXSpatialNorm3D(torch.nn.Module):
    """GroupNorm whose scale and bias are predicted from a second tensor ``zq``.

    ``zq`` is resized to the (T, H, W) size of the features and mapped by two
    1x1x1 convolutions to a multiplicative and an additive term.
    """

    def __init__(self, f_channels, zq_channels, groups):
        super().__init__()
        self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
        self.conv_y = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
        self.conv_b = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)

    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
        """Return ``norm(f) * conv_y(zq') + conv_b(zq')`` with ``zq'`` resized to ``f``."""
        resize = torch.nn.functional.interpolate
        num_frames = f.shape[2]

        if num_frames > 1 and num_frames % 2 == 1:
            # Odd clip: the first frame of zq is resized separately to the
            # first frame of f, the remaining frames to the remaining frames.
            first_size = f[:, :, :1].shape[-3:]
            rest_size = f[:, :, 1:].shape[-3:]
            zq = torch.cat(
                [
                    resize(zq[:, :, :1], size=first_size),
                    resize(zq[:, :, 1:], size=rest_size),
                ],
                dim=2,
            )
        else:
            zq = resize(zq, size=f.shape[-3:])

        return self.norm_layer(f) * self.conv_y(zq) + self.conv_b(zq)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Resnet3DBlock(torch.nn.Module):
    """Residual block: (norm -> SiLU -> cached 3-D conv) twice, plus a shortcut.

    Args:
        in_channels: Channels of the incoming features.
        out_channels: Channels produced by the block.
        spatial_norm_dim: Channel count of the conditioning tensor ``zq``.
            ``None`` selects plain ``GroupNorm``; any other value selects
            ``CogVideoXSpatialNorm3D`` conditioned on ``zq``.
        groups: Number of normalisation groups.
        eps: Epsilon for the plain ``GroupNorm`` layers.
        use_conv_shortcut: When the channel counts differ, use a cached 3x3x3
            convolution for the shortcut instead of a 1x1x1 convolution.
    """

    def __init__(self, in_channels, out_channels, spatial_norm_dim, groups, eps=1e-6, use_conv_shortcut=False):
        super().__init__()
        self.nonlinearity = torch.nn.SiLU()
        if spatial_norm_dim is None:
            self.norm1 = torch.nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
            self.norm2 = torch.nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
        else:
            self.norm1 = CogVideoXSpatialNorm3D(in_channels, spatial_norm_dim, groups)
            self.norm2 = CogVideoXSpatialNorm3D(out_channels, spatial_norm_dim, groups)

        # No temporal padding: CachedConv3d prepends cached frames instead.
        self.conv1 = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
        self.conv2 = CachedConv3d(out_channels, out_channels, kernel_size=3, padding=(0, 1, 1))

        if in_channels == out_channels:
            # A parameter-free module rather than a lambda, so the block stays
            # picklable; it adds no entries to the state dict.
            self.conv_shortcut = torch.nn.Identity()
        elif use_conv_shortcut:
            self.conv_shortcut = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
        else:
            self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def _normalize(self, norm, hidden_states, zq):
        """Apply ``norm``, passing ``zq`` only to the conditioned variant."""
        if isinstance(norm, CogVideoXSpatialNorm3D):
            return norm(hidden_states, zq)
        return norm(hidden_states)

    def forward(self, hidden_states, zq):
        """Return ``shortcut(hidden_states) + conv path``; ``zq`` conditions the norms."""
        residual = hidden_states

        hidden_states = self._normalize(self.norm1, hidden_states, zq)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self._normalize(self.norm2, hidden_states, zq)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv2(hidden_states)

        return hidden_states + self.conv_shortcut(residual)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CachedConv3d(torch.nn.Conv3d):
    """``Conv3d`` that prepends two remembered frames along dim 2 before convolving.

    On the first cached call the two prepended frames are copies of the
    input's first frame.  After every cached call the last two frames of the
    extended input are stored for the next call, until ``clear_cache``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.cached_tensor = None

    def clear_cache(self):
        """Forget the stored frames so the next call starts a new sequence."""
        self.cached_tensor = None

    def forward(self, input: torch.Tensor, use_cache = True) -> torch.Tensor:
        """Convolve ``input``; with ``use_cache`` the stored frames are prepended first."""
        if not use_cache:
            return super().forward(input)

        history = self.cached_tensor
        if history is None:
            # Start of a sequence: repeat the first frame twice.
            first_frame = input[:, :, :1]
            history = torch.concat([first_frame, first_frame], dim=2)

        extended = torch.concat([history, input], dim=2)
        self.cached_tensor = extended[:, :, -2:]
        return super().forward(extended)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogVAEDecoder(torch.nn.Module):
    """Decoder turning 16-channel video latents into 3-channel frames.

    The network is a stack of ``Resnet3DBlock``s conditioned on the latent
    itself, with three ``Upsample3D`` stages (the first two also upsample
    time).  Temporal convolutions are ``CachedConv3d``s, so a long latent can
    be decoded chunk by chunk.
    """

    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.7
        self.conv_in = CachedConv3d(16, 512, kernel_size=3, stride=1, padding=(0, 1, 1))

        # Lists are built left to right, so module indices match a flat
        # literal listing of the same layers.
        self.blocks = torch.nn.ModuleList(
            [Resnet3DBlock(512, 512, 16, 32) for _ in range(6)]
            + [Upsample3D(512, 512, compress_time=True)]
            + [Resnet3DBlock(512, 256, 16, 32)]
            + [Resnet3DBlock(256, 256, 16, 32) for _ in range(3)]
            + [Upsample3D(256, 256, compress_time=True)]
            + [Resnet3DBlock(256, 256, 16, 32) for _ in range(4)]
            + [Upsample3D(256, 256, compress_time=False)]
            + [Resnet3DBlock(256, 128, 16, 32)]
            + [Resnet3DBlock(128, 128, 16, 32) for _ in range(3)]
        )

        self.norm_out = CogVideoXSpatialNorm3D(128, 16, 32)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = CachedConv3d(128, 3, kernel_size=3, stride=1, padding=(0, 1, 1))

    def forward(self, sample):
        """Decode one chunk; the conv caches carry state over from earlier chunks."""
        sample = sample / self.scaling_factor
        hidden_states = self.conv_in(sample)

        # The unscaled latent is also the conditioning input of every block.
        for block in self.blocks:
            hidden_states = block(hidden_states, sample)

        hidden_states = self.norm_out(hidden_states, sample)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states

    def decode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
        """Decode a ``(B, C, T, H, W)`` latent, optionally in spatial tiles.

        With ``tiled`` the work is delegated to ``TileWorker2Dto3D``, which
        calls ``decode_small_video`` per tile; ``scales`` describes how the
        channel, time, height and width sizes change from input to output.
        """
        if not tiled:
            return self.decode_small_video(sample)

        T = sample.shape[2]
        return TileWorker2Dto3D().tiled_forward(
            forward_fn=lambda x: self.decode_small_video(x),
            model_input=sample,
            tile_size=tile_size, tile_stride=tile_stride,
            tile_device=sample.device, tile_dtype=sample.dtype,
            computation_device=sample.device, computation_dtype=sample.dtype,
            scales=(3/16, (T//2*8+T%2)/T, 8, 8),
            progress_bar=progress_bar
        )

    def decode_small_video(self, sample):
        """Decode a latent chunk by chunk along time and concatenate the frames.

        Chunks hold two latent frames; when ``T`` is odd the first chunk also
        takes the leading frame.  A latent shorter than two frames is decoded
        as one chunk instead of failing on an empty chunk list.  The conv
        caches are cleared when decoding ends, including on error, so the next
        call starts from a clean state.
        """
        B, C, T, H, W = sample.shape
        computation_device = self.conv_in.weight.device
        computation_dtype = self.conv_in.weight.dtype
        value = []
        try:
            # max(..., 1): guarantee one chunk even when T < 2.
            for i in range(max(T // 2, 1)):
                tl = i*2 + T%2 - (T%2 and i==0)
                tr = i*2 + 2 + T%2
                model_input = sample[:, :, tl: tr, :, :].to(dtype=computation_dtype, device=computation_device)
                model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
                value.append(model_output)
            return torch.concat(value, dim=2)
        finally:
            for module in self.modules():
                if isinstance(module, CachedConv3d):
                    module.clear_cache()

    @staticmethod
    def state_dict_converter():
        """Return the converter that maps external checkpoints to this layout."""
        return CogVAEDecoderStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogVAEEncoder(torch.nn.Module):
    """Encoder turning 3-channel video frames into 16-channel latents.

    The network is a stack of ``Resnet3DBlock``s with plain ``GroupNorm`` and
    three ``Downsample3D`` stages (the first two also compress time).  The
    output convolution produces 32 channels, of which the first 16 are kept
    and multiplied by ``scaling_factor``.
    """

    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.7
        self.conv_in = CachedConv3d(3, 128, kernel_size=3, stride=1, padding=(0, 1, 1))

        # Lists are built left to right, so module indices match a flat
        # literal listing of the same layers.
        self.blocks = torch.nn.ModuleList(
            [Resnet3DBlock(128, 128, None, 32) for _ in range(3)]
            + [Downsample3D(128, 128, compress_time=True)]
            + [Resnet3DBlock(128, 256, None, 32)]
            + [Resnet3DBlock(256, 256, None, 32) for _ in range(2)]
            + [Downsample3D(256, 256, compress_time=True)]
            + [Resnet3DBlock(256, 256, None, 32) for _ in range(3)]
            + [Downsample3D(256, 256, compress_time=False)]
            + [Resnet3DBlock(256, 512, None, 32)]
            + [Resnet3DBlock(512, 512, None, 32) for _ in range(4)]
        )

        self.norm_out = torch.nn.GroupNorm(32, 512, eps=1e-06, affine=True)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = CachedConv3d(512, 32, kernel_size=3, stride=1, padding=(0, 1, 1))

    def forward(self, sample):
        """Encode one chunk; the conv caches carry state over from earlier chunks."""
        hidden_states = self.conv_in(sample)

        # Every block takes a second positional argument; the encoder blocks
        # are built with spatial_norm_dim=None and Downsample3D ignores it.
        for block in self.blocks:
            hidden_states = block(hidden_states, sample)

        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        # Keep only the first 16 of the 32 output channels.
        hidden_states = self.conv_out(hidden_states)[:, :16]
        hidden_states = hidden_states * self.scaling_factor

        return hidden_states

    def encode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
        """Encode a ``(B, C, T, H, W)`` video, optionally in spatial tiles.

        ``tile_size`` and ``tile_stride`` are multiplied by 8 before being
        passed on (the same defaults serve ``CogVAEDecoder.decode_video``
        unscaled, and ``scales`` shrinks height and width by 8).
        """
        if not tiled:
            return self.encode_small_video(sample)

        T = sample.shape[2]
        return TileWorker2Dto3D().tiled_forward(
            forward_fn=lambda x: self.encode_small_video(x),
            model_input=sample,
            # Real tuples, not one-shot generators, so the values survive
            # being read more than once.
            tile_size=tuple(i * 8 for i in tile_size), tile_stride=tuple(i * 8 for i in tile_stride),
            tile_device=sample.device, tile_dtype=sample.dtype,
            computation_device=sample.device, computation_dtype=sample.dtype,
            scales=(16/3, (T//4+T%2)/T, 1/8, 1/8),
            progress_bar=progress_bar
        )

    def encode_small_video(self, sample):
        """Encode a video chunk by chunk along time and concatenate the latents.

        Chunks hold eight frames; when ``T`` is odd the first chunk also takes
        the leading frame.  A video shorter than eight frames is encoded as
        one chunk instead of failing on an empty chunk list.  The conv caches
        are cleared when encoding ends, including on error.
        NOTE(review): for longer videos only ``T // 8`` chunks are read, so
        frames beyond the last full chunk appear to be dropped -- confirm that
        callers always pass ``T`` of the form ``8k`` or ``8k + 1``.
        """
        B, C, T, H, W = sample.shape
        computation_device = self.conv_in.weight.device
        computation_dtype = self.conv_in.weight.dtype
        value = []
        try:
            # max(..., 1): guarantee one chunk even when T < 8.
            for i in range(max(T // 8, 1)):
                t = i*8 + T%2 - (T%2 and i==0)
                t_ = i*8 + 8 + T%2
                model_input = sample[:, :, t: t_, :, :].to(dtype=computation_dtype, device=computation_device)
                model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
                value.append(model_output)
            return torch.concat(value, dim=2)
        finally:
            for module in self.modules():
                if isinstance(module, CachedConv3d):
                    module.clear_cache()

    @staticmethod
    def state_dict_converter():
        """Return the converter that maps external checkpoints to this layout."""
        return CogVAEEncoderStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogVAEEncoderStateDictConverter:
    """Renames ``encoder.*`` checkpoint keys to the flat ``blocks.N`` layout."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Return a new dict holding the renamed encoder entries of ``state_dict``.

        Keys that match neither a direct rename nor a known resnet prefix are
        dropped. A key under a known resnet prefix whose remainder is not in
        the suffix table raises ``KeyError``.
        """
        leaves = ("weight", "bias")

        # Stand-alone layers: stem, downsamplers and output head.
        direct = {}
        for leaf in leaves:
            direct[f"encoder.conv_in.conv.{leaf}"] = f"conv_in.{leaf}"
            direct[f"encoder.norm_out.{leaf}"] = f"norm_out.{leaf}"
            direct[f"encoder.conv_out.conv.{leaf}"] = f"conv_out.{leaf}"
            for stage, slot in enumerate((3, 7, 11)):
                direct[f"encoder.down_blocks.{stage}.downsamplers.0.conv.{leaf}"] = f"blocks.{slot}.conv.{leaf}"

        # Resnet containers: three per down stage (a downsampler occupies the
        # fourth slot of the first three stages), then two mid-block resnets.
        containers = {}
        for stage in range(4):
            for index in range(3):
                containers[f"encoder.down_blocks.{stage}.resnets.{index}."] = f"blocks.{stage * 4 + index}."
        for index in range(2):
            containers[f"encoder.mid_block.resnets.{index}."] = f"blocks.{15 + index}."

        # Parameter names inside a resnet; wrapped convs lose their ".conv" level.
        tails = {}
        for leaf in leaves:
            for norm in ("norm1", "norm2"):
                tails[f"{norm}.norm_layer.{leaf}"] = f"{norm}.norm_layer.{leaf}"
                tails[f"{norm}.{leaf}"] = f"{norm}.{leaf}"
                for gate in ("conv_y", "conv_b"):
                    tails[f"{norm}.{gate}.conv.{leaf}"] = f"{norm}.{gate}.{leaf}"
            for conv in ("conv1", "conv2"):
                tails[f"{conv}.conv.{leaf}"] = f"{conv}.{leaf}"
            tails[f"conv_shortcut.{leaf}"] = f"conv_shortcut.{leaf}"

        converted = {}
        for key, tensor in state_dict.items():
            target = direct.get(key)
            if target is not None:
                converted[target] = tensor
                continue
            for old_prefix in [p for p in containers if key.startswith(p)]:
                remainder = key[len(old_prefix):]
                converted[containers[old_prefix] + tails[remainder]] = tensor
        return converted

    def from_civitai(self, state_dict):
        """Same conversion as ``from_diffusers``."""
        return self.from_diffusers(state_dict)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogVAEDecoderStateDictConverter:
    """Renames ``decoder.*`` checkpoint keys to the flat ``blocks.N`` layout."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Return a new dict holding the renamed decoder entries of ``state_dict``.

        Keys that match neither a direct rename nor a known resnet prefix are
        dropped. A key under a known resnet prefix whose remainder is not in
        the suffix table raises ``KeyError``.
        """
        leaves = ("weight", "bias")

        # Stand-alone layers: stem, upsamplers, output norm and output conv.
        direct = {}
        for leaf in leaves:
            direct[f"decoder.conv_in.conv.{leaf}"] = f"conv_in.{leaf}"
            direct[f"decoder.conv_out.conv.{leaf}"] = f"conv_out.{leaf}"
            direct[f"decoder.norm_out.norm_layer.{leaf}"] = f"norm_out.norm_layer.{leaf}"
            for gate in ("conv_y", "conv_b"):
                direct[f"decoder.norm_out.{gate}.conv.{leaf}"] = f"norm_out.{gate}.{leaf}"
            for stage, slot in enumerate((6, 11, 16)):
                direct[f"decoder.up_blocks.{stage}.upsamplers.0.conv.{leaf}"] = f"blocks.{slot}.conv.{leaf}"

        # Resnet containers: two mid-block resnets first, then four per up
        # stage (an upsampler occupies the fifth slot of the first three).
        containers = {}
        for index in range(2):
            containers[f"decoder.mid_block.resnets.{index}."] = f"blocks.{index}."
        for stage in range(4):
            for index in range(4):
                containers[f"decoder.up_blocks.{stage}.resnets.{index}."] = f"blocks.{2 + stage * 5 + index}."

        # Parameter names inside a resnet; wrapped convs lose their ".conv" level.
        tails = {}
        for leaf in leaves:
            for norm in ("norm1", "norm2"):
                tails[f"{norm}.norm_layer.{leaf}"] = f"{norm}.norm_layer.{leaf}"
                for gate in ("conv_y", "conv_b"):
                    tails[f"{norm}.{gate}.conv.{leaf}"] = f"{norm}.{gate}.{leaf}"
            for conv in ("conv1", "conv2"):
                tails[f"{conv}.conv.{leaf}"] = f"{conv}.{leaf}"
            tails[f"conv_shortcut.{leaf}"] = f"conv_shortcut.{leaf}"

        converted = {}
        for key, tensor in state_dict.items():
            target = direct.get(key)
            if target is not None:
                converted[target] = tensor
                continue
            for old_prefix in [p for p in containers if key.startswith(p)]:
                remainder = key[len(old_prefix):]
                converted[containers[old_prefix] + tails[remainder]] = tensor
        return converted

    def from_civitai(self, state_dict):
        """Same conversion as ``from_diffusers``."""
        return self.from_diffusers(state_dict)
|
|
||||||
|
|
||||||
@@ -1,111 +0,0 @@
|
|||||||
from huggingface_hub import hf_hub_download
|
|
||||||
from modelscope import snapshot_download
|
|
||||||
import os, shutil
|
|
||||||
from typing_extensions import Literal, TypeAlias
|
|
||||||
from typing import List
|
|
||||||
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
|
|
||||||
|
|
||||||
|
|
||||||
def download_from_modelscope(model_id, origin_file_path, local_dir):
    """Fetch one file from a ModelScope repository into ``local_dir``.

    Nothing is downloaded when a file with the same base name is already
    listed in ``local_dir``. Otherwise the file is fetched with
    ``snapshot_download``; if it arrives under a nested path it is moved
    directly into ``local_dir`` and the first directory component of
    ``origin_file_path`` is removed from ``local_dir``.
    """
    os.makedirs(local_dir, exist_ok=True)
    file_name = os.path.basename(origin_file_path)

    if file_name in os.listdir(local_dir):
        print(f" {file_name} has been already in {local_dir}.")
        return

    print(f" Start downloading {os.path.join(local_dir, file_name)}")
    snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)

    fetched_path = os.path.join(local_dir, origin_file_path)
    flat_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
    if fetched_path == flat_path:
        # The file already sits directly inside local_dir.
        return

    shutil.move(fetched_path, flat_path)
    top_level_dir = origin_file_path.split("/")[0]
    shutil.rmtree(os.path.join(local_dir, top_level_dir))
|
|
||||||
|
|
||||||
|
|
||||||
def download_from_huggingface(model_id, origin_file_path, local_dir):
    """Fetch one file from a Hugging Face repository into ``local_dir``.

    Nothing is downloaded when a file with the same base name is already
    listed in ``local_dir``. Otherwise the file is fetched with
    ``hf_hub_download``; if it arrives under a nested path it is moved
    directly into ``local_dir`` and the first directory component of
    ``origin_file_path`` is removed from ``local_dir``.
    """
    os.makedirs(local_dir, exist_ok=True)
    file_name = os.path.basename(origin_file_path)

    if file_name in os.listdir(local_dir):
        print(f" {file_name} has been already in {local_dir}.")
        return

    print(f" Start downloading {os.path.join(local_dir, file_name)}")
    hf_hub_download(model_id, origin_file_path, local_dir=local_dir)

    fetched_path = os.path.join(local_dir, origin_file_path)
    flat_path = os.path.join(local_dir, file_name)
    if fetched_path == flat_path:
        # The file already sits directly inside local_dir.
        return

    shutil.move(fetched_path, flat_path)
    top_level_dir = origin_file_path.split("/")[0]
    shutil.rmtree(os.path.join(local_dir, top_level_dir))
|
|
||||||
|
|
||||||
|
|
||||||
# The hub names accepted wherever a download source is selected.
Preset_model_website: TypeAlias = Literal[
    "HuggingFace",
    "ModelScope",
]
# Hub name -> table of preset models, imported from ..configs.model_config.
website_to_preset_models = {
    "HuggingFace": preset_models_on_huggingface,
    "ModelScope": preset_models_on_modelscope,
}
# Hub name -> function that downloads a single file from that hub.
website_to_download_fn = {
    "HuggingFace": download_from_huggingface,
    "ModelScope": download_from_modelscope,
}
|
|
||||||
|
|
||||||
|
|
||||||
def download_customized_models(
    model_id,
    origin_file_path,
    local_dir,
    downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
    """Download a single file, trying each hub in ``downloading_priority`` order.

    A hub is skipped once the target path has been recorded as downloaded.
    Returns the list of local paths that were recorded (at most one entry,
    since every hub targets the same path).
    """
    downloaded_files = []
    for website in downloading_priority:
        destination = os.path.join(local_dir, os.path.basename(origin_file_path))
        if destination not in downloaded_files:
            fetch = website_to_download_fn[website]
            fetch(model_id, origin_file_path, local_dir)
            # Record the file only if it is really present afterwards.
            if os.path.basename(origin_file_path) in os.listdir(local_dir):
                downloaded_files.append(destination)
    return downloaded_files
|
|
||||||
|
|
||||||
|
|
||||||
def download_models(
    model_id_list: List[Preset_model_id] = [],
    downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
    """Download preset models and return the paths that should be loaded.

    For every preset id the hubs in ``downloading_priority`` are tried in
    order; the first hub that yields at least one file wins. Preset metadata
    is either a list of ``(repo_id, origin_file_path, local_dir)`` entries or
    a dict with a ``"file_list"`` of such entries and an optional
    ``"load_path"`` that replaces the downloaded file paths in the result.

    Args:
        model_id_list: Preset model ids to download. The default list is never
            mutated here.
        downloading_priority: Hub names in the order they should be tried.
            The default list is never mutated here.

    Returns:
        List of file paths to load, in the order the presets were processed.
    """
    print(f"Downloading models: {model_id_list}")
    downloaded_files = []
    load_files = []

    for model_id in model_id_list:
        for website in downloading_priority:
            if model_id in website_to_preset_models[website]:

                # Parse model metadata
                model_metadata = website_to_preset_models[website][model_id]
                if isinstance(model_metadata, list):
                    file_data = model_metadata
                else:
                    file_data = model_metadata.get("file_list", [])

                # Try downloading the model from this website.
                model_files = []
                # The repository id gets its own name so the preset id of the
                # outer loop is still intact when the next website is tried.
                for repo_id, origin_file_path, local_dir in file_data:
                    # Check if the file is downloaded.
                    file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
                    if file_to_download in downloaded_files:
                        continue
                    # Download
                    website_to_download_fn[website](repo_id, origin_file_path, local_dir)
                    if os.path.basename(origin_file_path) in os.listdir(local_dir):
                        downloaded_files.append(file_to_download)
                        model_files.append(file_to_download)

                # If the model is successfully downloaded, break.
                if len(model_files) > 0:
                    if isinstance(model_metadata, dict) and "load_path" in model_metadata:
                        model_files = model_metadata["load_path"]
                    load_files.extend(model_files)
                    break

    return load_files
|
|
||||||
@@ -1,327 +0,0 @@
|
|||||||
import torch
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
|
|
||||||
from .utils import hash_state_dict_keys, init_weights_on_device
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FluxControlNet(torch.nn.Module):
    """ControlNet branch for a Flux-style transformer.

    The module embeds the latents together with a conditioning image, runs
    ``num_joint_blocks`` joint blocks and ``num_single_blocks`` single blocks,
    and projects the hidden state after each block through its own linear
    layer. ``forward`` returns those projections, repeated so that they line
    up with 19 joint and 38 single blocks.
    """

    def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict=None, additional_input_dim=0):
        """Build the ControlNet.

        Args:
            disable_guidance_embedder: When true, no guidance embedder is created.
            num_joint_blocks: Number of joint transformer blocks.
            num_single_blocks: Number of single transformer blocks.
            num_mode: Number of rows of the mode embedding table.
            mode_dict: Maps a processor id to its row in the mode embedding.
                The mode embedder is only created when this is non-empty.
                ``None`` is treated as an empty mapping.
            additional_input_dim: Extra channels of the patchified conditioning
                input on top of the base 64.
        """
        super().__init__()
        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
        self.time_embedder = TimestepEmbeddings(256, 3072)
        self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
        self.context_embedder = torch.nn.Linear(4096, 3072)
        self.x_embedder = torch.nn.Linear(64, 3072)

        self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
        self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])

        # One output projection per transformer block.
        self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
        self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])

        # A fresh dict per instance avoids sharing one default object.
        self.mode_dict = {} if mode_dict is None else mode_dict
        self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(self.mode_dict) > 0 else None
        self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)

    def prepare_image_ids(self, latents):
        """Build position ids of shape (batch, (H//2)*(W//2), 3) for ``latents``.

        Channel 0 stays zero, channel 1 holds the row index and channel 2 the
        column index of each 2x2 patch. The result uses the device and dtype
        of ``latents``.
        """
        batch_size, _, height, width = latents.shape
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
        latent_image_ids = latent_image_ids.reshape(
            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )
        latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)

        return latent_image_ids

    def patchify(self, hidden_states):
        """Fold every 2x2 spatial patch into the channel axis: (B, C, H, W) -> (B, H/2*W/2, C*4)."""
        hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
        return hidden_states

    def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
        """Stretch ``res_stack`` to ``num_blocks`` entries by repeating each residual.

        An empty stack yields ``num_blocks`` references to a single zero tensor
        shaped like ``hidden_states``. Otherwise each residual serves
        ``ceil(num_blocks / len(res_stack))`` consecutive block slots.
        """
        if len(res_stack) == 0:
            return [torch.zeros_like(hidden_states)] * num_blocks
        interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
        aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
        return aligned_res_stack

    def forward(
        self,
        hidden_states,
        controlnet_conditioning,
        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
        processor_id=None,
        tiled=False, tile_size=128, tile_stride=64,
        **kwargs
    ):
        """Compute the ControlNet residuals.

        Args:
            hidden_states: Latents, unpacked as (B, C, H, W) by ``prepare_image_ids``.
            controlnet_conditioning: Conditioning input, patchified like the latents.
            timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids:
                Conditioning inputs forwarded to the embedders.
            image_ids: Optional precomputed position ids; built from
                ``hidden_states`` when omitted.
            processor_id: Key into ``mode_dict``; only read when a mode
                embedder exists.
            tiled, tile_size, tile_stride, **kwargs: Accepted but not used here.

        Returns:
            ``(controlnet_res_stack, controlnet_single_res_stack)``, lists of
            19 and 38 residual tensors respectively.
        """
        if image_ids is None:
            image_ids = self.prepare_image_ids(hidden_states)

        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
        if self.guidance_embedder is not None:
            guidance = guidance * 1000
            conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
        prompt_emb = self.context_embedder(prompt_emb)
        if self.controlnet_mode_embedder is not None: # Different from FluxDiT
            # Prepend one mode token to the prompt and duplicate the first
            # text id so that the id sequence keeps the same length.
            processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
            processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
            prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
            text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

        hidden_states = self.patchify(hidden_states)
        hidden_states = self.x_embedder(hidden_states)
        controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT

        controlnet_res_stack = []
        for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
            hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
            controlnet_res_stack.append(controlnet_block(hidden_states))

        controlnet_single_res_stack = []
        # Single blocks operate on the concatenated [prompt, image] sequence;
        # the prompt part is sliced off again before each projection.
        hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
        for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
            hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
            controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))

        controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
        controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])

        return controlnet_res_stack, controlnet_single_res_stack

    @staticmethod
    def state_dict_converter():
        """Return the converter that renames checkpoint keys for this module."""
        return FluxControlNetStateDictConverter()

    def quantize(self):
        """Swap Linear, RMSNorm and Embedding children for cast-on-the-fly variants.

        The replacement layers keep the stored parameters and cast them to the
        dtype/device of the incoming activation on every forward call.
        """
        def cast_to(weight, dtype=None, device=None, copy=False):
            # Return the tensor itself when nothing has to change.
            if device is None or weight.device == device:
                if not copy:
                    if dtype is None or weight.dtype == dtype:
                        return weight
                return weight.to(dtype=dtype, copy=copy)

            r = torch.empty_like(weight, dtype=dtype, device=device)
            r.copy_(weight)
            return r

        def cast_weight(s, input=None, dtype=None, device=None):
            # dtype/device default to those of ``input`` when it is given.
            if input is not None:
                if dtype is None:
                    dtype = input.dtype
                if device is None:
                    device = input.device
            weight = cast_to(s.weight, dtype, device)
            return weight

        def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
            if input is not None:
                if dtype is None:
                    dtype = input.dtype
                if bias_dtype is None:
                    bias_dtype = dtype
                if device is None:
                    device = input.device
            weight = cast_to(s.weight, dtype, device)
            # Layers created with bias=False have no bias tensor to cast.
            bias = None
            if s.bias is not None:
                bias = cast_to(s.bias, bias_dtype, device)
            return weight, bias

        class quantized_layer:
            class QLinear(torch.nn.Linear):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def forward(self,input,**kwargs):
                    weight,bias= cast_bias_weight(self,input)
                    return torch.nn.functional.linear(input,weight,bias)

            class QRMSNorm(torch.nn.Module):
                def __init__(self, module):
                    super().__init__()
                    self.module = module

                def forward(self,hidden_states,**kwargs):
                    weight= cast_weight(self.module,hidden_states)
                    input_dtype = hidden_states.dtype
                    # The variance is accumulated in float32.
                    variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
                    hidden_states = hidden_states.to(input_dtype) * weight
                    return hidden_states

            class QEmbedding(torch.nn.Embedding):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def forward(self,input,**kwargs):
                    # NOTE(review): ``input`` holds indices, so this casts the
                    # table to the index dtype - confirm this is intended.
                    weight= cast_weight(self,input)
                    return torch.nn.functional.embedding(
                        input, weight, self.padding_idx, self.max_norm,
                        self.norm_type, self.scale_grad_by_freq, self.sparse)

        def replace_layer(model):
            for name, module in model.named_children():
                if isinstance(module,quantized_layer.QRMSNorm):
                    continue
                if isinstance(module, torch.nn.Linear):
                    with init_weights_on_device():
                        # Mirror the bias setting of the source layer so that
                        # no uninitialised bias is left behind.
                        new_layer = quantized_layer.QLinear(
                            module.in_features, module.out_features,
                            bias=module.bias is not None)
                    new_layer.weight = module.weight
                    if module.bias is not None:
                        new_layer.bias = module.bias
                    setattr(model, name, new_layer)
                elif isinstance(module, RMSNorm):
                    # The flag marks norms that were already wrapped.
                    if hasattr(module,"quantized"):
                        continue
                    module.quantized= True
                    new_layer = quantized_layer.QRMSNorm(module)
                    setattr(model, name, new_layer)
                elif isinstance(module,torch.nn.Embedding):
                    rows, cols = module.weight.shape
                    new_layer = quantized_layer.QEmbedding(
                        num_embeddings=rows,
                        embedding_dim=cols,
                        _weight=module.weight,
                        # _freeze=module.freeze,
                        padding_idx=module.padding_idx,
                        max_norm=module.max_norm,
                        norm_type=module.norm_type,
                        scale_grad_by_freq=module.scale_grad_by_freq,
                        sparse=module.sparse)
                    setattr(model, name, new_layer)
                else:
                    replace_layer(module)

        replace_layer(self)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FluxControlNetStateDictConverter:
    """Renames checkpoint keys to the names used by ``FluxControlNet``."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Convert ``state_dict`` and pick constructor kwargs from its key hash.

        Returns:
            ``(state_dict_, extra_kwargs)``: the renamed parameters with
            q/k/v projections fused, and the keyword arguments selected by
            the hash of the original keys (empty when the hash is unknown).
        """
        # Computed on the untouched keys; selects the architecture variant below.
        hash_value = hash_state_dict_keys(state_dict)
        # Module names outside the transformer blocks.
        global_rename_dict = {
            "context_embedder": "context_embedder",
            "x_embedder": "x_embedder",
            "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
            "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
            "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
            "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
            "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
            "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
            "norm_out.linear": "final_norm_out.linear",
            "proj_out": "final_proj_out",
        }
        # Module names inside a joint block ("transformer_blocks.N.").
        rename_dict = {
            "proj_out": "proj_out",
            "norm1.linear": "norm1_a.linear",
            "norm1_context.linear": "norm1_b.linear",
            "attn.to_q": "attn.a_to_q",
            "attn.to_k": "attn.a_to_k",
            "attn.to_v": "attn.a_to_v",
            "attn.to_out.0": "attn.a_to_out",
            "attn.add_q_proj": "attn.b_to_q",
            "attn.add_k_proj": "attn.b_to_k",
            "attn.add_v_proj": "attn.b_to_v",
            "attn.to_add_out": "attn.b_to_out",
            "ff.net.0.proj": "ff_a.0",
            "ff.net.2": "ff_a.2",
            "ff_context.net.0.proj": "ff_b.0",
            "ff_context.net.2": "ff_b.2",
            "attn.norm_q": "attn.norm_q_a",
            "attn.norm_k": "attn.norm_k_a",
            "attn.norm_added_q": "attn.norm_q_b",
            "attn.norm_added_k": "attn.norm_k_b",
        }
        # Module names inside a single block ("single_transformer_blocks.N.").
        rename_dict_single = {
            "attn.to_q": "a_to_q",
            "attn.to_k": "a_to_k",
            "attn.to_v": "a_to_v",
            "attn.norm_q": "norm_q_a",
            "attn.norm_k": "norm_k_a",
            "norm.linear": "norm.linear",
            "proj_mlp": "proj_in_besides_attn",
            "proj_out": "proj_out",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name.endswith(".weight") or name.endswith(".bias"):
                # Split "<module path>.<weight|bias>" into prefix and suffix.
                suffix = ".weight" if name.endswith(".weight") else ".bias"
                prefix = name[:-len(suffix)]
                if prefix in global_rename_dict:
                    state_dict_[global_rename_dict[prefix] + suffix] = param
                elif prefix.startswith("transformer_blocks."):
                    names = prefix.split(".")
                    names[0] = "blocks"
                    # Everything after "<container>.<index>" is the in-block module path.
                    middle = ".".join(names[2:])
                    if middle in rename_dict:
                        name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
                        state_dict_[name_] = param
                elif prefix.startswith("single_transformer_blocks."):
                    names = prefix.split(".")
                    names[0] = "single_blocks"
                    middle = ".".join(names[2:])
                    if middle in rename_dict_single:
                        name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
                        state_dict_[name_] = param
                    else:
                        # Unmapped single-block entries keep their original name.
                        state_dict_[name] = param
                else:
                    # Everything else is copied through unchanged.
                    state_dict_[name] = param
        # Single blocks: fuse q, k, v and the MLP input projection into one
        # "to_qkv_mlp" tensor, concatenated along dim 0 in that order.
        for name in list(state_dict_.keys()):
            if ".proj_in_besides_attn." in name:
                name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
                param = torch.concat([
                    state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
                    state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
                    state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
                    state_dict_[name],
                ], dim=0)
                state_dict_[name_] = param
                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
                state_dict_.pop(name)
        # Remaining q/k/v triples (streams "a" and "b") are fused into
        # "<stream>_to_qkv", again concatenated along dim 0.
        for name in list(state_dict_.keys()):
            for component in ["a", "b"]:
                if f".{component}_to_q." in name:
                    name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
                    param = torch.concat([
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
                    ], dim=0)
                    state_dict_[name_] = param
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
        # Known key-set hashes map to the constructor arguments of that variant.
        if hash_value == "78d18b9101345ff695f312e7e62538c0":
            extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
        elif hash_value == "b001c89139b5f053c715fe772362dd2a":
            extra_kwargs = {"num_single_blocks": 0}
        elif hash_value == "52357cb26250681367488a8954c271e8":
            extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
        elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
            extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
        else:
            extra_kwargs = {}
        return state_dict_, extra_kwargs

    def from_civitai(self, state_dict):
        """Same conversion as ``from_diffusers``."""
        return self.from_diffusers(state_dict)
|
|
||||||
@@ -1,739 +0,0 @@
|
|||||||
import torch
|
|
||||||
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm
|
|
||||||
from einops import rearrange
|
|
||||||
from .tiler import TileWorker
|
|
||||||
from .utils import init_weights_on_device
|
|
||||||
|
|
||||||
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
    """Add scaled attention over the IP-Adapter keys/values to ``hidden_states``.

    ``q`` attends to ``ip_k``/``ip_v``; the result has dims 1 and 2 swapped
    and is flattened to ``(batch, tokens, -1)`` before being added with
    weight ``scale``. ``batch`` and ``tokens`` are taken from the first two
    dimensions of ``hidden_states``.
    """
    batch_size, num_tokens = hidden_states.shape[:2]
    attended = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
    attended = attended.transpose(1, 2).reshape(batch_size, num_tokens, -1)
    return hidden_states + scale * attended
|
|
||||||
|
|
||||||
|
|
||||||
class RoPEEmbedding(torch.nn.Module):
    """Builds 2x2 rotation matrices from multi-axis position ids."""

    def __init__(self, dim, theta, axes_dim):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
        """Return rotation matrices of shape (batch, seq, dim // 2, 2, 2) as float32.

        ``pos`` must be two-dimensional (batch, seq). Frequencies are computed
        in float64 on the device of ``pos``.
        """
        assert dim % 2 == 0, "The dimension must be even."

        exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
        omega = 1.0 / (theta**exponents)

        batch_size, _ = pos.shape
        angles = torch.einsum("...n,d->...nd", pos, omega)
        cosines = torch.cos(angles)
        sines = torch.sin(angles)

        # Row-major entries of [[cos, -sin], [sin, cos]] for every frequency.
        entries = torch.stack([cosines, -sines, sines, cosines], dim=-1)
        matrices = entries.view(batch_size, -1, dim // 2, 2, 2)
        return matrices.float()

    def forward(self, ids):
        """Concatenate the per-axis rotations along the frequency axis and add a dim at index 1."""
        per_axis = []
        for axis in range(ids.shape[-1]):
            per_axis.append(self.rope(ids[..., axis], self.axes_dim[axis], self.theta))
        emb = torch.cat(per_axis, dim=-3)
        return emb.unsqueeze(1)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FluxJointAttention(torch.nn.Module):
|
|
||||||
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
|
|
||||||
super().__init__()
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.head_dim = head_dim
|
|
||||||
self.only_out_a = only_out_a
|
|
||||||
|
|
||||||
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
|
|
||||||
self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
|
|
||||||
|
|
||||||
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
|
|
||||||
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
|
|
||||||
self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
|
|
||||||
self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
|
|
||||||
|
|
||||||
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
|
|
||||||
if not only_out_a:
|
|
||||||
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rope(self, xq, xk, freqs_cis):
|
|
||||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
|
||||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
|
||||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
|
||||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
|
||||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
|
||||||
|
|
||||||
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
|
||||||
batch_size = hidden_states_a.shape[0]
|
|
||||||
|
|
||||||
# Part A
|
|
||||||
qkv_a = self.a_to_qkv(hidden_states_a)
|
|
||||||
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
|
|
||||||
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
|
|
||||||
|
|
||||||
# Part B
|
|
||||||
qkv_b = self.b_to_qkv(hidden_states_b)
|
|
||||||
qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
|
|
||||||
q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)
|
|
||||||
|
|
||||||
q = torch.concat([q_b, q_a], dim=2)
|
|
||||||
k = torch.concat([k_b, k_a], dim=2)
|
|
||||||
v = torch.concat([v_b, v_a], dim=2)
|
|
||||||
|
|
||||||
q, k = self.apply_rope(q, k, image_rotary_emb)
|
|
||||||
|
|
||||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
|
||||||
hidden_states = hidden_states.to(q.dtype)
|
|
||||||
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
|
|
||||||
if ipadapter_kwargs_list is not None:
|
|
||||||
hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
|
|
||||||
hidden_states_a = self.a_to_out(hidden_states_a)
|
|
||||||
if self.only_out_a:
|
|
||||||
return hidden_states_a
|
|
||||||
else:
|
|
||||||
hidden_states_b = self.b_to_out(hidden_states_b)
|
|
||||||
return hidden_states_a, hidden_states_b
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FluxJointTransformerBlock(torch.nn.Module):
    """Dual-stream transformer block.

    Streams A and B each get their own adaptive norm and feed-forward network
    but share a single ``FluxJointAttention``.
    """

    def __init__(self, dim, num_attention_heads):
        super().__init__()

        def make_feed_forward():
            return torch.nn.Sequential(
                torch.nn.Linear(dim, dim * 4),
                torch.nn.GELU(approximate="tanh"),
                torch.nn.Linear(dim * 4, dim),
            )

        self.norm1_a = AdaLayerNorm(dim)
        self.norm1_b = AdaLayerNorm(dim)

        head_dim = dim // num_attention_heads
        self.attn = FluxJointAttention(dim, dim, num_attention_heads, head_dim)

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_a = make_feed_forward()

        self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_b = make_feed_forward()

    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        """Run gated attention and gated MLP residual updates on both streams.

        Returns the updated ``(hidden_states_a, hidden_states_b)``.
        """
        # Each adaptive norm yields the normalised input plus four modulation tensors.
        normed_a, attn_gate_a, mlp_shift_a, mlp_scale_a, mlp_gate_a = self.norm1_a(hidden_states_a, emb=temb)
        normed_b, attn_gate_b, mlp_shift_b, mlp_scale_b, mlp_gate_b = self.norm1_b(hidden_states_b, emb=temb)

        # Attention over both streams at once
        attn_a, attn_b = self.attn(normed_a, normed_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)

        # Stream A
        hidden_states_a = hidden_states_a + attn_gate_a * attn_a
        mlp_in_a = self.norm2_a(hidden_states_a) * (1 + mlp_scale_a) + mlp_shift_a
        hidden_states_a = hidden_states_a + mlp_gate_a * self.ff_a(mlp_in_a)

        # Stream B
        hidden_states_b = hidden_states_b + attn_gate_b * attn_b
        mlp_in_b = self.norm2_b(hidden_states_b) * (1 + mlp_scale_b) + mlp_shift_b
        hidden_states_b = hidden_states_b + mlp_gate_b * self.ff_b(mlp_in_b)

        return hidden_states_a, hidden_states_b
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FluxSingleAttention(torch.nn.Module):
|
|
||||||
def __init__(self, dim_a, dim_b, num_heads, head_dim):
|
|
||||||
super().__init__()
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.head_dim = head_dim
|
|
||||||
|
|
||||||
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
|
|
||||||
|
|
||||||
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
|
|
||||||
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rope(self, xq, xk, freqs_cis):
|
|
||||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
|
||||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
|
||||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
|
||||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
|
||||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, hidden_states, image_rotary_emb):
|
|
||||||
batch_size = hidden_states.shape[0]
|
|
||||||
|
|
||||||
qkv_a = self.a_to_qkv(hidden_states)
|
|
||||||
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
q_a, k_a, v = qkv_a.chunk(3, dim=1)
|
|
||||||
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
|
|
||||||
|
|
||||||
q, k = self.apply_rope(q_a, k_a, image_rotary_emb)
|
|
||||||
|
|
||||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
|
||||||
hidden_states = hidden_states.to(q.dtype)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AdaLayerNormSingle(torch.nn.Module):
    """Adaptive LayerNorm that produces a shift, a scale and a gate from ``emb``."""

    def __init__(self, dim):
        super().__init__()
        self.silu = torch.nn.SiLU()
        self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        """Return ``(modulated_x, gate)``; shift/scale are broadcast over dim 1 of ``x``."""
        modulation = self.linear(self.silu(emb))
        shift, scale, gate = torch.chunk(modulation, 3, dim=1)
        modulated = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return modulated, gate
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FluxSingleTransformerBlock(torch.nn.Module):
|
|
||||||
def __init__(self, dim, num_attention_heads):
|
|
||||||
super().__init__()
|
|
||||||
self.num_heads = num_attention_heads
|
|
||||||
self.head_dim = dim // num_attention_heads
|
|
||||||
self.dim = dim
|
|
||||||
|
|
||||||
self.norm = AdaLayerNormSingle(dim)
|
|
||||||
self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4))
|
|
||||||
self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
|
|
||||||
self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)
|
|
||||||
|
|
||||||
self.proj_out = torch.nn.Linear(dim * 5, dim)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rope(self, xq, xk, freqs_cis):
|
|
||||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
|
||||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
|
||||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
|
||||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
|
||||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
|
||||||
|
|
||||||
|
|
||||||
def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
|
||||||
batch_size = hidden_states.shape[0]
|
|
||||||
|
|
||||||
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
q, k, v = qkv.chunk(3, dim=1)
|
|
||||||
q, k = self.norm_q_a(q), self.norm_k_a(k)
|
|
||||||
|
|
||||||
q, k = self.apply_rope(q, k, image_rotary_emb)
|
|
||||||
|
|
||||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
|
||||||
hidden_states = hidden_states.to(q.dtype)
|
|
||||||
if ipadapter_kwargs_list is not None:
|
|
||||||
hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
|
||||||
residual = hidden_states_a
|
|
||||||
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
|
|
||||||
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
|
|
||||||
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
|
|
||||||
|
|
||||||
attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
|
|
||||||
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
|
|
||||||
|
|
||||||
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
|
||||||
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
|
|
||||||
hidden_states_a = residual + hidden_states_a
|
|
||||||
|
|
||||||
return hidden_states_a, hidden_states_b
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AdaLayerNormContinuous(torch.nn.Module):
    """LayerNorm whose scale and shift are predicted from a conditioning vector."""

    def __init__(self, dim):
        super().__init__()
        self.silu = torch.nn.SiLU()
        self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
        self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)

    def forward(self, x, conditioning):
        """Return ``norm(x)`` modulated by the conditioning-derived scale and shift."""
        modulation = self.linear(self.silu(conditioning))
        # The first half of the projection is the scale, the second half the shift.
        scale, shift = modulation.chunk(2, dim=1)
        return self.norm(x) * (1 + scale).unsqueeze(1) + shift.unsqueeze(1)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FluxDiT(torch.nn.Module):
    """FLUX diffusion transformer.

    The model patchifies a latent image into tokens, runs 19 dual-stream
    (``FluxJointTransformerBlock``) blocks over image and prompt tokens, then
    38 single-stream (``FluxSingleTransformerBlock``) blocks over their
    concatenation, and finally projects the image tokens back to latent space.
    """

    def __init__(self, disable_guidance_embedder=False):
        """Build the model; ``disable_guidance_embedder=True`` omits the guidance embedder."""
        super().__init__()
        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
        self.time_embedder = TimestepEmbeddings(256, 3072)
        self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
        self.context_embedder = torch.nn.Linear(4096, 3072)
        self.x_embedder = torch.nn.Linear(64, 3072)

        self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
        self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])

        self.final_norm_out = AdaLayerNormContinuous(3072)
        self.final_proj_out = torch.nn.Linear(3072, 64)

    def patchify(self, hidden_states):
        """Fold every 2x2 spatial patch into the channel axis: ``B C (H 2) (W 2) -> B (H W) (C 4)``."""
        hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
        return hidden_states

    def unpatchify(self, hidden_states, height, width):
        """Invert ``patchify``; ``height`` and ``width`` are the un-patched spatial sizes."""
        hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
        return hidden_states

    def prepare_image_ids(self, latents):
        """Build position ids for the image tokens of ``latents``.

        Returns a tensor of shape ``(batch, (height // 2) * (width // 2), 3)``
        on the device/dtype of ``latents``. Channel 0 is left at zero, channel 1
        holds the patch row index and channel 2 the patch column index.
        """
        batch_size, _, height, width = latents.shape
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
        latent_image_ids = latent_image_ids.reshape(
            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )
        latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)

        return latent_image_ids

    def tiled_forward(
        self,
        hidden_states,
        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
        tile_size=128, tile_stride=64,
        **kwargs
    ):
        """Run ``forward`` tile by tile through ``TileWorker``.

        ``image_ids`` is recomputed per tile (passed as ``None``). Extra
        ``kwargs`` are accepted for call-site compatibility but are not
        forwarded to the per-tile ``forward`` call.
        """
        # Due to the global positional embedding, we cannot implement layer-wise tiled forward.
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None),
            hidden_states,
            tile_size,
            tile_stride,
            tile_device=hidden_states.device,
            tile_dtype=hidden_states.dtype
        )
        return hidden_states

    def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
        """Build an additive attention mask for entity-controlled generation.

        The sequence layout is ``N`` prompt segments of ``prompt_seq_len``
        tokens followed by ``image_seq_len`` image tokens, where ``N`` is
        ``len(entity_masks)``. Prompt segment ``i`` and an image token may
        attend to each other only where patchified mask ``i`` has a positive
        sum; different prompt segments never attend to each other. The result
        is a float tensor of shape ``(batch, total, total)`` holding ``0`` for
        allowed pairs and ``-inf`` for blocked ones.
        """
        N = len(entity_masks)
        batch_size = entity_masks[0].shape[0]
        total_seq_len = N * prompt_seq_len + image_seq_len
        patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
        attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)

        image_start = N * prompt_seq_len
        image_end = N * prompt_seq_len + image_seq_len
        # prompt-image mask
        for i in range(N):
            prompt_start = i * prompt_seq_len
            prompt_end = (i + 1) * prompt_seq_len
            image_mask = torch.sum(patched_masks[i], dim=-1) > 0
            image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
            # prompt update with image
            attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
            # image update with prompt
            attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
        # prompt-prompt mask
        for i in range(N):
            for j in range(N):
                if i != j:
                    prompt_start_i = i * prompt_seq_len
                    prompt_end_i = (i + 1) * prompt_seq_len
                    prompt_start_j = j * prompt_seq_len
                    prompt_end_j = (j + 1) * prompt_seq_len
                    attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False

        # Convert the boolean "allowed" mask into an additive one (0 / -inf).
        attention_mask = attention_mask.float()
        attention_mask[attention_mask == 0] = float('-inf')
        attention_mask[attention_mask == 1] = 0
        return attention_mask

    def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
        """Assemble prompt embeddings, rotary embeddings and the attention mask.

        When ``entity_masks`` is given, each entity prompt is placed before the
        global prompt and an all-ones global mask is appended after the entity
        masks. Returns ``(prompt_emb, image_rotary_emb, attention_mask)``;
        ``attention_mask`` is ``None`` when there are no entity masks.
        """
        # NOTE(review): forward() calls this with the embedded token sequence, so
        # this is the image token count; construct_mask only tests whether each
        # patch sum is positive, so the tiling below looks redundant - confirm.
        repeat_dim = hidden_states.shape[1]
        max_masks = 0
        attention_mask = None
        prompt_embs = [prompt_emb]
        if entity_masks is not None:
            # entity_masks
            max_masks = entity_masks.shape[1]
            entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
            entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
            # global mask
            global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
            entity_masks = entity_masks + [global_mask] # append global to last
            # attention mask
            attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
            attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
            # Extra axis at dim 1 so the mask broadcasts over attention heads.
            attention_mask = attention_mask.unsqueeze(1)
            # embds: n_masks * b * seq * d
            local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
            prompt_embs = local_embs + prompt_embs # append global to last
        prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
        prompt_emb = torch.cat(prompt_embs, dim=1)

        # positional embedding
        text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
        return prompt_emb, image_rotary_emb, attention_mask

    def forward(
        self,
        hidden_states,
        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
        tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
        use_gradient_checkpointing=False,
        **kwargs
    ):
        """Run the transformer on latent ``hidden_states`` and return a tensor of the same layout.

        With ``tiled=True`` the call is delegated to ``tiled_forward``. Entity
        control is active only when both ``entity_prompt_emb`` and
        ``entity_masks`` are provided. Gradient checkpointing is applied only
        when the module is in training mode and ``use_gradient_checkpointing``
        is set.
        """
        if tiled:
            return self.tiled_forward(
                hidden_states,
                timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
                tile_size=tile_size, tile_stride=tile_stride,
                **kwargs
            )

        if image_ids is None:
            image_ids = self.prepare_image_ids(hidden_states)

        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
        if self.guidance_embedder is not None:
            # Guidance is scaled by 1000 before it goes through its embedder.
            guidance = guidance * 1000
            conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)

        height, width = hidden_states.shape[-2:]
        hidden_states = self.patchify(hidden_states)
        hidden_states = self.x_embedder(hidden_states)

        if entity_prompt_emb is not None and entity_masks is not None:
            prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
        else:
            prompt_emb = self.context_embedder(prompt_emb)
            image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
            attention_mask = None

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
                    use_reentrant=False,
                )
            else:
                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)

        # Single-stream blocks operate on [prompt tokens, image tokens].
        hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
        for block in self.single_blocks:
            if self.training and use_gradient_checkpointing:
                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
                    use_reentrant=False,
                )
            else:
                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
        # Drop the prompt tokens again; only image tokens are decoded.
        hidden_states = hidden_states[:, prompt_emb.shape[1]:]

        hidden_states = self.final_norm_out(hidden_states, conditioning)
        hidden_states = self.final_proj_out(hidden_states)
        hidden_states = self.unpatchify(hidden_states, height, width)

        return hidden_states

    def quantize(self):
        """Replace Linear and RMSNorm submodules with on-the-fly casting wrappers.

        No weights are converted here: the wrappers keep the stored parameters
        and cast them to the dtype/device of the incoming activation on every
        forward call. The replacement happens in place on ``self``.
        """
        def cast_to(weight, dtype=None, device=None, copy=False):
            # Same device: return as-is or let ``Tensor.to`` handle the dtype change.
            if device is None or weight.device == device:
                if not copy:
                    if dtype is None or weight.dtype == dtype:
                        return weight
                return weight.to(dtype=dtype, copy=copy)

            # Different device: allocate on the target and copy across.
            r = torch.empty_like(weight, dtype=dtype, device=device)
            r.copy_(weight)
            return r

        def cast_weight(s, input=None, dtype=None, device=None):
            if input is not None:
                if dtype is None:
                    dtype = input.dtype
                if device is None:
                    device = input.device
            weight = cast_to(s.weight, dtype, device)
            return weight

        def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
            if input is not None:
                if dtype is None:
                    dtype = input.dtype
                if bias_dtype is None:
                    bias_dtype = dtype
                if device is None:
                    device = input.device
            weight = cast_to(s.weight, dtype, device)
            # A layer built with bias=False has ``bias is None``; casting it would
            # dereference ``None.device``, so it is passed through untouched.
            bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None
            return weight, bias

        class quantized_layer:
            class Linear(torch.nn.Linear):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def forward(self,input,**kwargs):
                    weight,bias= cast_bias_weight(self,input)
                    return torch.nn.functional.linear(input,weight,bias)

            class RMSNorm(torch.nn.Module):
                def __init__(self, module):
                    super().__init__()
                    self.module = module

                def forward(self,hidden_states,**kwargs):
                    weight= cast_weight(self.module,hidden_states)
                    input_dtype = hidden_states.dtype
                    # The variance is accumulated in float32 before casting back.
                    variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
                    hidden_states = hidden_states.to(input_dtype) * weight
                    return hidden_states

        def replace_layer(model):
            for name, module in model.named_children():
                if isinstance(module, torch.nn.Linear):
                    with init_weights_on_device():
                        # Mirror the original layer's bias setting so a bias-free
                        # layer does not gain a stray, never-loaded bias parameter.
                        new_layer = quantized_layer.Linear(module.in_features,module.out_features,bias=module.bias is not None)
                    new_layer.weight = module.weight
                    if module.bias is not None:
                        new_layer.bias = module.bias
                    setattr(model, name, new_layer)
                elif isinstance(module, RMSNorm):
                    # The wrapper keeps the original norm as a child; the flag stops
                    # it from being wrapped again when the recursion reaches it.
                    if hasattr(module,"quantized"):
                        continue
                    module.quantized= True
                    new_layer = quantized_layer.RMSNorm(module)
                    setattr(model, name, new_layer)
                else:
                    replace_layer(module)

        replace_layer(self)

    @staticmethod
    def state_dict_converter():
        """Return the converter that maps external checkpoints onto this model's parameter names."""
        return FluxDiTStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class FluxDiTStateDictConverter:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
    """Translate a Diffusers-layout FLUX state dict to this model's parameter names.

    Only entries ending in ``.weight`` or ``.bias`` whose prefix is listed in
    one of the rename tables are kept; everything else is dropped. After
    renaming, separate q/k/v tensors are concatenated along dim 0:
    single-stream blocks fuse ``q, k, v, proj_mlp`` into ``to_qkv_mlp`` and
    joint blocks fuse ``q, k, v`` into ``a_to_qkv`` / ``b_to_qkv``.
    """
    top_level_names = {
        "context_embedder": "context_embedder",
        "x_embedder": "x_embedder",
        "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
        "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
        "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
        "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
        "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
        "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
        "norm_out.linear": "final_norm_out.linear",
        "proj_out": "final_proj_out",
    }
    joint_block_names = {
        "proj_out": "proj_out",
        "norm1.linear": "norm1_a.linear",
        "norm1_context.linear": "norm1_b.linear",
        "attn.to_q": "attn.a_to_q",
        "attn.to_k": "attn.a_to_k",
        "attn.to_v": "attn.a_to_v",
        "attn.to_out.0": "attn.a_to_out",
        "attn.add_q_proj": "attn.b_to_q",
        "attn.add_k_proj": "attn.b_to_k",
        "attn.add_v_proj": "attn.b_to_v",
        "attn.to_add_out": "attn.b_to_out",
        "ff.net.0.proj": "ff_a.0",
        "ff.net.2": "ff_a.2",
        "ff_context.net.0.proj": "ff_b.0",
        "ff_context.net.2": "ff_b.2",
        "attn.norm_q": "attn.norm_q_a",
        "attn.norm_k": "attn.norm_k_a",
        "attn.norm_added_q": "attn.norm_q_b",
        "attn.norm_added_k": "attn.norm_k_b",
    }
    single_block_names = {
        "attn.to_q": "a_to_q",
        "attn.to_k": "a_to_k",
        "attn.to_v": "a_to_v",
        "attn.norm_q": "norm_q_a",
        "attn.norm_k": "norm_k_a",
        "norm.linear": "norm.linear",
        "proj_mlp": "proj_in_besides_attn",
        "proj_out": "proj_out",
    }
    # (source prefix, target root, rename table) for the two kinds of blocks.
    block_tables = (
        ("transformer_blocks.", "blocks", joint_block_names),
        ("single_transformer_blocks.", "single_blocks", single_block_names),
    )

    converted = {}
    for name, param in state_dict.items():
        if name.endswith(".weight"):
            suffix = ".weight"
        elif name.endswith(".bias"):
            suffix = ".bias"
        else:
            continue
        prefix = name[:-len(suffix)]

        if prefix in top_level_names:
            converted[top_level_names[prefix] + suffix] = param
            continue

        for source_root, target_root, table in block_tables:
            if not prefix.startswith(source_root):
                continue
            parts = prefix.split(".")
            inner = ".".join(parts[2:])
            if inner in table:
                # <root>.<block index>.<renamed module path>.<weight|bias>
                converted[".".join([target_root, parts[1], table[inner], suffix[1:]])] = param
            break

    # Single-stream blocks: fuse q, k, v and the MLP input projection.
    mlp_tag = ".proj_in_besides_attn."
    for name in list(converted):
        if mlp_tag not in name:
            continue
        q_name = name.replace(mlp_tag, ".a_to_q.")
        k_name = name.replace(mlp_tag, ".a_to_k.")
        v_name = name.replace(mlp_tag, ".a_to_v.")
        fused = torch.concat(
            [converted[q_name], converted[k_name], converted[v_name], converted[name]],
            dim=0,
        )
        converted[name.replace(mlp_tag, ".to_qkv_mlp.")] = fused
        for stale in (q_name, k_name, v_name, name):
            converted.pop(stale)

    # Joint blocks: fuse q, k, v for each of the two streams.
    for name in list(converted):
        for component in ("a", "b"):
            q_tag = f".{component}_to_q."
            if q_tag not in name:
                continue
            k_name = name.replace(q_tag, f".{component}_to_k.")
            v_name = name.replace(q_tag, f".{component}_to_v.")
            fused = torch.concat([converted[name], converted[k_name], converted[v_name]], dim=0)
            converted[name.replace(q_tag, f".{component}_to_qkv.")] = fused
            for stale in (name, k_name, v_name):
                converted.pop(stale)
    return converted
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
    """Rename the keys of an original-layout FLUX checkpoint to this model's names.

    Keys may optionally carry a ``model.diffusion_model.`` prefix, which is
    stripped before lookup. Top-level keys are mapped through ``rename_dict``;
    keys under ``double_blocks.<i>.`` and ``single_blocks.<i>.`` are mapped
    through ``suffix_rename_dict`` and re-rooted at ``blocks.<i>.`` and
    ``single_blocks.<i>.`` respectively. Any other key is dropped.

    Args:
        state_dict: Mapping from parameter name to tensor.

    Returns:
        The converted mapping. If the converted mapping has no
        ``guidance_embedder.timestep_embedder.0.weight`` entry, a tuple
        ``(converted, {"disable_guidance_embedder": True})`` is returned
        instead of the bare mapping.

    Raises:
        KeyError: If a ``double_blocks`` key has a suffix that is not listed
            in ``suffix_rename_dict`` (unknown ``single_blocks`` suffixes are
            skipped silently).
    """
    rename_dict = {
        "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
        "time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
        "time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias",
        "time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight",
        "txt_in.bias": "context_embedder.bias",
        "txt_in.weight": "context_embedder.weight",
        "vector_in.in_layer.bias": "pooled_text_embedder.0.bias",
        "vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
        "vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
        "vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
        "final_layer.linear.bias": "final_proj_out.bias",
        "final_layer.linear.weight": "final_proj_out.weight",
        "guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
        "guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
        "guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
        "guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
        "img_in.bias": "x_embedder.bias",
        "img_in.weight": "x_embedder.weight",
        "final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight",
        "final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias",
    }
    suffix_rename_dict = {
        # Double-stream blocks: "img" branch -> "a", "txt" branch -> "b".
        "img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
        "img_attn.norm.query_norm.scale": "attn.norm_q_a.weight",
        "img_attn.proj.bias": "attn.a_to_out.bias",
        "img_attn.proj.weight": "attn.a_to_out.weight",
        "img_attn.qkv.bias": "attn.a_to_qkv.bias",
        "img_attn.qkv.weight": "attn.a_to_qkv.weight",
        "img_mlp.0.bias": "ff_a.0.bias",
        "img_mlp.0.weight": "ff_a.0.weight",
        "img_mlp.2.bias": "ff_a.2.bias",
        "img_mlp.2.weight": "ff_a.2.weight",
        "img_mod.lin.bias": "norm1_a.linear.bias",
        "img_mod.lin.weight": "norm1_a.linear.weight",
        "txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight",
        "txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight",
        "txt_attn.proj.bias": "attn.b_to_out.bias",
        "txt_attn.proj.weight": "attn.b_to_out.weight",
        "txt_attn.qkv.bias": "attn.b_to_qkv.bias",
        "txt_attn.qkv.weight": "attn.b_to_qkv.weight",
        "txt_mlp.0.bias": "ff_b.0.bias",
        "txt_mlp.0.weight": "ff_b.0.weight",
        "txt_mlp.2.bias": "ff_b.2.bias",
        "txt_mlp.2.weight": "ff_b.2.weight",
        "txt_mod.lin.bias": "norm1_b.linear.bias",
        "txt_mod.lin.weight": "norm1_b.linear.weight",

        # Single-stream blocks.
        "linear1.bias": "to_qkv_mlp.bias",
        "linear1.weight": "to_qkv_mlp.weight",
        "linear2.bias": "proj_out.bias",
        "linear2.weight": "proj_out.weight",
        "modulation.lin.bias": "norm.linear.bias",
        "modulation.lin.weight": "norm.linear.weight",
        "norm.key_norm.scale": "norm_k_a.weight",
        "norm.query_norm.scale": "norm_q_a.weight",
    }
    state_dict_ = {}
    for name, param in state_dict.items():
        if name.startswith("model.diffusion_model."):
            name = name[len("model.diffusion_model."):]
        names = name.split(".")
        if name in rename_dict:
            rename = rename_dict[name]
            if name.startswith("final_layer.adaLN_modulation.1."):
                # Swap the two halves along dim 0. The split point is derived
                # from the tensor itself rather than hard-coded to 3072, so
                # the swap also holds for other hidden sizes.
                # NOTE(review): presumably the source stores (shift, scale)
                # and the target expects (scale, shift) -- confirm.
                half = param.shape[0] // 2
                param = torch.concat([param[half:], param[:half]], dim=0)
            state_dict_[rename] = param
        elif names[0] == "double_blocks":
            rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
            state_dict_[rename] = param
        elif names[0] == "single_blocks":
            suffix = ".".join(names[2:])
            if suffix in suffix_rename_dict:
                rename = f"single_blocks.{names[1]}." + suffix_rename_dict[suffix]
                state_dict_[rename] = param
    if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
        # No guidance embedder weights: return an extra kwargs dict alongside
        # the converted state dict.
        return state_dict_, {"disable_guidance_embedder": True}
    else:
        return state_dict_
|
|
||||||
@@ -1,94 +0,0 @@
|
|||||||
from .svd_image_encoder import SVDImageEncoder
|
|
||||||
from .sd3_dit import RMSNorm
|
|
||||||
from transformers import CLIPImageProcessor
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class MLPProjModel(torch.nn.Module):
    """MLP that expands each input embedding into ``num_tokens`` layer-normed tokens.

    The projection is Linear -> GELU -> Linear, producing
    ``cross_attention_dim * num_tokens`` features per input row, which are
    then reshaped to ``(-1, num_tokens, cross_attention_dim)`` and passed
    through a LayerNorm over the last dimension.
    """

    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        hidden_dim = id_embeddings_dim * 2
        flat_token_dim = cross_attention_dim * num_tokens
        layers = [
            torch.nn.Linear(id_embeddings_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, flat_token_dim),
        ]
        self.proj = torch.nn.Sequential(*layers)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        """Project ``id_embeds`` and return the normalized token tensor."""
        tokens = self.proj(id_embeds).reshape(-1, self.num_tokens, self.cross_attention_dim)
        return self.norm(tokens)
|
|
||||||
|
|
||||||
class IpAdapterModule(torch.nn.Module):
    """Bias-free key/value projections for one attention block.

    ``forward`` returns a ``(key, value)`` pair, each viewed as
    ``(batch, -1, num_heads, head_dim)`` and then transposed on dims 1 and 2;
    only the key is passed through ``norm_added_k``.
    """

    def __init__(self, num_attention_heads, attention_head_dim, input_dim):
        super().__init__()
        self.num_heads = num_attention_heads
        self.head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.to_k_ip = torch.nn.Linear(input_dim, inner_dim, bias=False)
        self.to_v_ip = torch.nn.Linear(input_dim, inner_dim, bias=False)
        self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False)

    def forward(self, hidden_states):
        """Return the per-head ``(ip_k, ip_v)`` tensors for ``hidden_states``."""
        per_head_shape = (hidden_states.shape[0], -1, self.num_heads, self.head_dim)
        # Key path: project, split into heads, normalize.
        keys = self.to_k_ip(hidden_states).view(*per_head_shape).transpose(1, 2)
        keys = self.norm_added_k(keys)
        # Value path: project and split into heads (no normalization).
        values = self.to_v_ip(hidden_states).view(*per_head_shape).transpose(1, 2)
        return keys, values
|
|
||||||
|
|
||||||
|
|
||||||
class FluxIpAdapter(torch.nn.Module):
    """Builds per-block IP-Adapter key/value tensors from an image embedding.

    Holds ``num_blocks`` ``IpAdapterModule`` instances plus an ``MLPProjModel``
    image projector (constructed with ``id_embeddings_dim=1152``).
    ``call_block_id`` maps a block id to the index of the adapter module used
    for it; ``set_adapter`` initializes it to the identity mapping.
    """

    def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):
        super().__init__()
        adapters = []
        for _ in range(num_blocks):
            adapters.append(IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim))
        self.ipadapter_modules = torch.nn.ModuleList(adapters)
        self.image_proj = MLPProjModel(
            cross_attention_dim=cross_attention_dim,
            id_embeddings_dim=1152,
            num_tokens=num_tokens,
        )
        self.set_adapter()

    def set_adapter(self):
        """Map every block id to the adapter module with the same index."""
        indices = range(len(self.ipadapter_modules))
        self.call_block_id = dict(zip(indices, indices))

    def forward(self, hidden_states, scale=1.0):
        """Return ``{block_id: {"ip_k", "ip_v", "scale"}}`` for every mapped block."""
        tokens = self.image_proj(hidden_states)
        # Collapse everything but the feature dim into a single batch entry.
        tokens = tokens.view(1, -1, tokens.shape[-1])
        ip_kv_dict = {}
        for block_id, ipadapter_id in self.call_block_id.items():
            ip_k, ip_v = self.ipadapter_modules[ipadapter_id](tokens)
            ip_kv_dict[block_id] = {"ip_k": ip_k, "ip_v": ip_v, "scale": scale}
        return ip_kv_dict

    @staticmethod
    def state_dict_converter():
        """Return the converter used to load checkpoints into this model."""
        return FluxIpAdapterStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class FluxIpAdapterStateDictConverter:
    """Flattens a two-section IP-Adapter checkpoint into ``FluxIpAdapter`` key names."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Prefix the ``"ip_adapter"`` and ``"image_proj"`` sections and merge them.

        Entries of ``state_dict["ip_adapter"]`` get the ``ipadapter_modules.``
        prefix; entries of ``state_dict["image_proj"]`` get ``image_proj.``.
        """
        adapter_section = state_dict["ip_adapter"]
        converted = {'ipadapter_modules.' + key: adapter_section[key] for key in adapter_section}
        projector_section = state_dict["image_proj"]
        converted.update(("image_proj." + key, projector_section[key]) for key in projector_section)
        return converted

    def from_civitai(self, state_dict):
        """Same conversion as ``from_diffusers``."""
        return self.from_diffusers(state_dict)
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
import torch
|
|
||||||
from transformers import T5EncoderModel, T5Config
|
|
||||||
from .sd_text_encoder import SDTextEncoder
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FluxTextEncoder2(T5EncoderModel):
    """``T5EncoderModel`` wrapper whose forward pass returns only the last hidden state.

    The model is switched to eval mode at construction time.
    """

    def __init__(self, config):
        super().__init__(config)
        self.eval()

    def forward(self, input_ids):
        """Encode ``input_ids`` and return ``last_hidden_state`` as the prompt embedding."""
        return super().forward(input_ids=input_ids).last_hidden_state

    @staticmethod
    def state_dict_converter():
        """Return the converter used to load checkpoints into this model."""
        return FluxTextEncoder2StateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FluxTextEncoder2StateDictConverter:
    """Pass-through converter: checkpoint keys are used unchanged."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Return ``state_dict`` itself (same object, no copy, no renaming)."""
        return state_dict

    def from_civitai(self, state_dict):
        """Same conversion as ``from_diffusers``."""
        return self.from_diffusers(state_dict)
|
|
||||||
@@ -1,303 +0,0 @@
|
|||||||
from .sd3_vae_encoder import SD3VAEEncoder, SDVAEEncoderStateDictConverter
|
|
||||||
from .sd3_vae_decoder import SD3VAEDecoder, SDVAEDecoderStateDictConverter
|
|
||||||
|
|
||||||
|
|
||||||
class FluxVAEEncoder(SD3VAEEncoder):
    """``SD3VAEEncoder`` with its scaling and shift factors overridden after construction."""

    def __init__(self):
        super().__init__()
        # Overrides applied on top of whatever the parent constructor set.
        self.scaling_factor, self.shift_factor = 0.3611, 0.1159

    @staticmethod
    def state_dict_converter():
        """Return the converter used to load checkpoints into this model."""
        return FluxVAEEncoderStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class FluxVAEDecoder(SD3VAEDecoder):
    """``SD3VAEDecoder`` with its scaling and shift factors overridden after construction."""

    def __init__(self):
        super().__init__()
        # Overrides applied on top of whatever the parent constructor set.
        self.scaling_factor, self.shift_factor = 0.3611, 0.1159

    @staticmethod
    def state_dict_converter():
        """Return the converter used to load checkpoints into this model."""
        return FluxVAEDecoderStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class FluxVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
    """Renames ``encoder.*`` checkpoint keys to the flat ``blocks.<n>`` layout."""

    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        """Return a new dict holding only the recognized keys, renamed.

        Keys absent from the rename table are dropped. Parameters whose new
        name contains ``transformer_blocks`` are passed through ``squeeze()``.
        """
        attn = "blocks.12.transformer_blocks.0"
        # Source module path -> target module path (".bias"/".weight" added below).
        module_map = {
            "encoder.conv_in": "conv_in",
            "encoder.conv_out": "conv_out",
            "encoder.norm_out": "conv_norm_out",
            "encoder.mid.attn_1.norm": "blocks.12.norm",
            "encoder.mid.attn_1.q": attn + ".to_q",
            "encoder.mid.attn_1.k": attn + ".to_k",
            "encoder.mid.attn_1.v": attn + ".to_v",
            "encoder.mid.attn_1.proj_out": attn + ".to_out",
            "encoder.down.1.block.0.nin_shortcut": "blocks.3.conv_shortcut",
            "encoder.down.2.block.0.nin_shortcut": "blocks.6.conv_shortcut",
        }
        # Residual blocks: each one contributes conv1/conv2/norm1/norm2.
        resnet_map = {
            "encoder.mid.block_1": "blocks.11",
            "encoder.mid.block_2": "blocks.13",
        }
        for stage in range(4):
            for idx in range(2):
                resnet_map[f"encoder.down.{stage}.block.{idx}"] = f"blocks.{3 * stage + idx}"
            if stage < 3:
                # The last stage has no downsample entry in the table.
                module_map[f"encoder.down.{stage}.downsample.conv"] = f"blocks.{3 * stage + 2}.conv"
        for src, dst in resnet_map.items():
            for layer in ("conv1", "conv2", "norm1", "norm2"):
                module_map[f"{src}.{layer}"] = f"{dst}.{layer}"
        rename_dict = {
            f"{src}.{kind}": f"{dst}.{kind}"
            for src, dst in module_map.items()
            for kind in ("bias", "weight")
        }

        state_dict_ = {}
        for name in state_dict:
            target = rename_dict.get(name)
            if target is None:
                continue
            param = state_dict[name]
            if "transformer_blocks" in target:
                # Drops all size-1 dims (presumably 1x1-conv weights being
                # loaded into linear layers -- confirm against the model).
                param = param.squeeze()
            state_dict_[target] = param
        return state_dict_
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FluxVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
    """Renames ``decoder.*`` checkpoint keys to the flat ``blocks.<n>`` layout."""

    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        """Return a new dict holding only the recognized keys, renamed.

        Keys absent from the rename table are dropped. Parameters whose new
        name contains ``transformer_blocks`` are passed through ``squeeze()``.
        """
        attn = "blocks.1.transformer_blocks.0"
        # Source module path -> target module path (".bias"/".weight" added below).
        module_map = {
            "decoder.conv_in": "conv_in",
            "decoder.conv_out": "conv_out",
            "decoder.norm_out": "conv_norm_out",
            "decoder.mid.attn_1.norm": "blocks.1.norm",
            "decoder.mid.attn_1.q": attn + ".to_q",
            "decoder.mid.attn_1.k": attn + ".to_k",
            "decoder.mid.attn_1.v": attn + ".to_v",
            "decoder.mid.attn_1.proj_out": attn + ".to_out",
            "decoder.up.0.block.0.nin_shortcut": "blocks.15.conv_shortcut",
            "decoder.up.1.block.0.nin_shortcut": "blocks.11.conv_shortcut",
        }
        # Residual blocks: each one contributes conv1/conv2/norm1/norm2.
        resnet_map = {
            "decoder.mid.block_1": "blocks.0",
            "decoder.mid.block_2": "blocks.2",
        }
        for stage in range(4):
            # "up" stages are numbered in reverse of the target block order:
            # up.3 -> blocks 3..6, up.2 -> 7..10, up.1 -> 11..14, up.0 -> 15..17.
            first = 15 - 4 * stage
            for idx in range(3):
                resnet_map[f"decoder.up.{stage}.block.{idx}"] = f"blocks.{first + idx}"
            if stage > 0:
                # up.0 has no upsample entry in the table.
                module_map[f"decoder.up.{stage}.upsample.conv"] = f"blocks.{first + 3}.conv"
        for src, dst in resnet_map.items():
            for layer in ("conv1", "conv2", "norm1", "norm2"):
                module_map[f"{src}.{layer}"] = f"{dst}.{layer}"
        rename_dict = {
            f"{src}.{kind}": f"{dst}.{kind}"
            for src, dst in module_map.items()
            for kind in ("bias", "weight")
        }

        state_dict_ = {}
        for name in state_dict:
            target = rename_dict.get(name)
            if target is None:
                continue
            param = state_dict[name]
            if "transformer_blocks" in target:
                # Drops all size-1 dims (presumably 1x1-conv weights being
                # loaded into linear layers -- confirm against the model).
                param = param.squeeze()
            state_dict_[target] = param
        return state_dict_
|
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
from .attention import Attention
|
from .attention import Attention
|
||||||
|
from .tiler import TileWorker
|
||||||
from einops import repeat, rearrange
|
from einops import repeat, rearrange
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
@@ -398,8 +399,7 @@ class HunyuanDiT(torch.nn.Module):
|
|||||||
hidden_states, _ = hidden_states.chunk(2, dim=1)
|
hidden_states, _ = hidden_states.chunk(2, dim=1)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return HunyuanDiTStateDictConverter()
|
return HunyuanDiTStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -79,8 +79,7 @@ class HunyuanDiTCLIPTextEncoder(BertModel):
|
|||||||
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
||||||
return prompt_emb
|
return prompt_emb
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return HunyuanDiTCLIPTextEncoderStateDictConverter()
|
return HunyuanDiTCLIPTextEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
@@ -132,8 +131,7 @@ class HunyuanDiTT5TextEncoder(T5EncoderModel):
|
|||||||
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
||||||
return prompt_emb
|
return prompt_emb
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return HunyuanDiTT5TextEncoderStateDictConverter()
|
return HunyuanDiTT5TextEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,885 +0,0 @@
|
|||||||
import torch
|
|
||||||
from .sd3_dit import TimestepEmbeddings, RMSNorm
|
|
||||||
from .utils import init_weights_on_device
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from tqdm import tqdm
|
|
||||||
from typing import Union, Tuple, List
|
|
||||||
|
|
||||||
|
|
||||||
def HunyuanVideoRope(latents):
    """Build real-valued rotary position tables for a 5-D latent tensor.

    The token grid is ``[latents.shape[2], latents.shape[3] // 2,
    latents.shape[4] // 2]`` (presumably frames, height/2, width/2 -- confirm
    against the caller). The rotary channels are split 16/56/56 over the three
    grid axes and use ``theta=256``.

    Returns:
        ``(freqs_cos, freqs_sin)``; each has one row per grid position and
        ``16 + 56 + 56 = 128`` columns.
    """

    def _to_tuple(x, dim=2):
        """Repeat an int ``dim`` times, or pass through a sequence of length ``dim``."""
        if isinstance(x, int):
            return (x,) * dim
        if len(x) != dim:
            raise ValueError(f"Expected length {dim} or int, but got {x}")
        return x

    def get_meshgrid_nd(start, *args, dim=2):
        """Return a stacked n-D coordinate grid as a float32 tensor ``[dim, ...]``.

        Call forms:
            ``(num)``               -> coordinates 0 .. num-1 on every axis
            ``(start, stop)``       -> unit step, ``stop - start`` points
            ``(start, stop, num)``  -> ``num`` evenly spaced points in [start, stop)

        Each of start/stop/num may be an int (shared by all axes) or a
        sequence with one entry per axis.
        """
        extra = len(args)
        if extra == 0:
            num = _to_tuple(start, dim=dim)
            start, stop = (0,) * dim, num
        elif extra == 1:
            start = _to_tuple(start, dim=dim)
            stop = _to_tuple(args[0], dim=dim)
            num = [hi - lo for lo, hi in zip(start, stop)]
        elif extra == 2:
            start = _to_tuple(start, dim=dim)
            stop = _to_tuple(args[0], dim=dim)
            num = _to_tuple(args[1], dim=dim)
        else:
            raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

        # torch equivalent of np.linspace(lo, hi, n, endpoint=False) per axis:
        # sample n + 1 points and drop the last one.
        axes = [
            torch.linspace(start[i], stop[i], num[i] + 1, dtype=torch.float32)[: num[i]]
            for i in range(dim)
        ]
        return torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=0)

    def get_1d_rotary_pos_embed(
        dim: int,
        pos: Union[torch.FloatTensor, int],
        theta: float = 10000.0,
        use_real: bool = False,
        theta_rescale_factor: float = 1.0,
        interpolation_factor: float = 1.0,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Compute rotary angles for 1-D positions.

        Args:
            dim: Number of rotary channels for this axis.
            pos: Positions ``[S]``; an int ``n`` means ``0 .. n-1``.
            theta: Base of the frequency schedule.
            use_real: Return ``(cos, sin)`` of shape ``[S, dim]`` instead of a
                complex tensor of shape ``[S, dim/2]``.
            theta_rescale_factor: NTK-style rescaling of ``theta``; 1.0 disables it.
            interpolation_factor: Multiplier applied to the positions.
        """
        if isinstance(pos, int):
            pos = torch.arange(pos).float()

        # Rescale the base so the embedding stretches to longer sequences
        # without fine-tuning.
        if theta_rescale_factor != 1.0:
            theta *= theta_rescale_factor ** (dim / (dim - 2))

        exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
        inv_freq = 1.0 / (theta ** exponents)  # [dim/2]
        angles = torch.outer(pos * interpolation_factor, inv_freq)  # [S, dim/2]

        if not use_real:
            return torch.polar(torch.ones_like(angles), angles)
        # Duplicate every frequency so cos/sin line up with (even, odd) channel pairs.
        return (
            angles.cos().repeat_interleave(2, dim=1),
            angles.sin().repeat_interleave(2, dim=1),
        )

    def get_nd_rotary_pos_embed(
        rope_dim_list,
        start,
        *args,
        theta=10000.0,
        use_real=False,
        theta_rescale_factor: Union[float, List[float]] = 1.0,
        interpolation_factor: Union[float, List[float]] = 1.0,
    ):
        """n-D rotary embedding: one 1-D embedding per grid axis, concatenated on channels.

        ``rope_dim_list[i]`` channels are assigned to grid axis ``i``;
        ``start`` and ``*args`` are forwarded to ``get_meshgrid_nd``. The two
        factor arguments may be a scalar, a one-element list, or a list with
        one entry per axis.
        """
        n_axes = len(rope_dim_list)
        grid = get_meshgrid_nd(start, *args, dim=n_axes)

        def _per_axis(value, label):
            # Broadcast a scalar / singleton list to one value per axis.
            if isinstance(value, (int, float)):
                value = [value] * n_axes
            elif isinstance(value, list) and len(value) == 1:
                value = [value[0]] * n_axes
            assert len(value) == n_axes, f"len({label}) should equal to len(rope_dim_list)"
            return value

        theta_rescale_factor = _per_axis(theta_rescale_factor, "theta_rescale_factor")
        interpolation_factor = _per_axis(interpolation_factor, "interpolation_factor")

        per_axis_embs = [
            get_1d_rotary_pos_embed(
                axis_dim,
                grid[axis].reshape(-1),
                theta,
                use_real=use_real,
                theta_rescale_factor=theta_rescale_factor[axis],
                interpolation_factor=interpolation_factor[axis],
            )
            for axis, axis_dim in enumerate(rope_dim_list)
        ]

        if use_real:
            cos = torch.cat([pair[0] for pair in per_axis_embs], dim=1)
            sin = torch.cat([pair[1] for pair in per_axis_embs], dim=1)
            return cos, sin
        return torch.cat(per_axis_embs, dim=1)

    token_grid = [latents.shape[2], latents.shape[3] // 2, latents.shape[4] // 2]
    return get_nd_rotary_pos_embed(
        [16, 56, 56],
        token_grid,
        theta=256,
        use_real=True,
        theta_rescale_factor=1,
    )
|
|
||||||
|
|
||||||
|
|
||||||
class PatchEmbed(torch.nn.Module):
    """Turn a 5-D input into a token sequence with a strided 3-D convolution."""

    def __init__(self, patch_size=(1, 2, 2), in_channels=16, embed_dim=3072):
        super().__init__()
        # kernel == stride, so patches do not overlap.
        self.proj = torch.nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Project ``x`` and return ``[B, num_patches, embed_dim]``."""
        patches = self.proj(x)
        # Collapse every axis after the channel axis, then put tokens before channels.
        tokens = torch.flatten(patches, start_dim=2)
        return tokens.transpose(1, 2)
|
|
||||||
|
|
||||||
|
|
||||||
class IndividualTokenRefinerBlock(torch.nn.Module):
    """Self-attention + MLP block whose residual branches are gated by a conditioning vector.

    Args:
        hidden_size: Channel width of the token stream.
        num_heads: Number of attention heads; must divide ``hidden_size``.
    """

    def __init__(self, hidden_size=3072, num_heads=24):
        super().__init__()
        self.num_heads = num_heads
        self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
        self.self_attn_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
        self.self_attn_proj = torch.nn.Linear(hidden_size, hidden_size)

        self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size * 4),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size * 4, hidden_size),
        )
        # Built on the default device/dtype like every other layer in this
        # block. The previous hard-coded device="cuda"/bfloat16 made
        # construction fail on hosts without CUDA and left this one layer
        # out of step with its siblings until the whole module was moved.
        self.adaLN_modulation = torch.nn.Sequential(
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, hidden_size * 2),
        )

    def forward(self, x, c, attn_mask=None):
        """Apply gated self-attention followed by a gated MLP.

        Args:
            x: Token stream ``[B, L, hidden_size]``.
            c: Conditioning vector ``[B, hidden_size]``.
            attn_mask: Optional mask forwarded to
                ``scaled_dot_product_attention``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        qkv = self.self_attn_qkv(self.norm1(x))
        batch, length, _ = qkv.shape
        # "B L (K H D) -> K B H L D" with K=3
        q, k, v = qkv.reshape(batch, length, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)

        attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        # "B H L D -> B L (H D)"
        attn = attn.transpose(1, 2).reshape(batch, length, -1)

        x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1)
        x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
        return x
|
|
||||||
|
|
||||||
|
|
||||||
class SingleTokenRefiner(torch.nn.Module):
    """Refine text-token embeddings with a small stack of conditioned attention blocks.

    Args:
        in_channels: Channel width of the incoming token embeddings.
        hidden_size: Channel width used inside the refiner and returned.
        depth: Number of ``IndividualTokenRefinerBlock`` layers.
    """

    def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
        super().__init__()
        self.input_embedder = torch.nn.Linear(in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
        self.c_embedder = torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden_size),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, hidden_size),
        )
        self.blocks = torch.nn.ModuleList(
            [IndividualTokenRefinerBlock(hidden_size=hidden_size) for _ in range(depth)]
        )

    def forward(self, x, t, mask=None):
        """Embed and refine the token sequence.

        Args:
            x: Token embeddings ``[B, L, in_channels]``.
            t: Timestep(s) passed to the timestep embedder.
            mask: Optional ``[B, L]`` validity mask (non-zero = keep). When
                omitted, every token is treated as valid. The signature has
                always advertised ``None`` as the default, but the body used
                to dereference it unconditionally.

        Returns:
            Refined tokens ``[B, L, hidden_size]``.
        """
        if mask is None:
            mask = torch.ones(x.shape[:2], dtype=torch.bool, device=x.device)

        timestep_aware_representations = self.t_embedder(t, dtype=torch.float32)

        # Masked mean over the sequence axis gives a pooled context vector.
        mask_float = mask.float().unsqueeze(-1)
        context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
        context_aware_representations = self.c_embedder(context_aware_representations)
        c = timestep_aware_representations + context_aware_representations

        x = self.input_embedder(x)

        # Pairwise attention mask [B, 1, L, L]: both query and key must be valid.
        mask = mask.to(device=x.device, dtype=torch.bool)
        mask = repeat(mask, "B L -> B 1 D L", D=mask.shape[-1])
        mask = mask & mask.transpose(2, 3)
        # Key 0 is always attendable, so no query row ends up fully masked.
        mask[:, :, :, 0] = True

        for block in self.blocks:
            x = block(x, c, mask)

        return x
|
|
||||||
|
|
||||||
|
|
||||||
class ModulateDiT(torch.nn.Module):
    """SiLU followed by a linear layer that widens the input by ``factor``.

    The output is meant to be chunked into ``factor`` modulation tensors of
    width ``hidden_size`` by the caller.
    """

    def __init__(self, hidden_size, factor=6):
        super().__init__()
        self.act = torch.nn.SiLU()
        self.linear = torch.nn.Linear(hidden_size, factor * hidden_size)

    def forward(self, x):
        """Return ``linear(silu(x))`` with last dimension ``factor * hidden_size``."""
        activated = self.act(x)
        return self.linear(activated)
|
|
||||||
|
|
||||||
|
|
||||||
def modulate(x, shift=None, scale=None):
    """Apply an optional per-sample affine modulation to ``x``.

    ``shift`` and ``scale`` get a new axis at position 1 before broadcasting,
    so a ``[B, D]`` modulation applies to every token of a ``[B, L, D]``
    input. The result is ``x * (1 + scale) + shift``, with any missing term
    skipped; with neither given, ``x`` itself is returned.
    """
    has_shift = shift is not None
    has_scale = scale is not None

    if has_scale and has_shift:
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
    if has_scale:
        return x * (1 + scale.unsqueeze(1))
    if has_shift:
        return x + shift.unsqueeze(1)
    return x
|
|
||||||
|
|
||||||
|
|
||||||
def reshape_for_broadcast(
    freqs_cis,
    x: torch.Tensor,
    head_first=False,
):
    """View rotary tables so they broadcast against ``x``.

    Args:
        freqs_cis: Either a ``(cos, sin)`` tuple of real tensors or a single
            (complex) tensor, each of shape ``[S, D]``.
        x: Tensor the tables will be multiplied with. Its last axis must have
            size ``D``; its sequence axis (axis 1, or axis -2 when
            ``head_first``) must have size ``S``.
        head_first: Whether the sequence axis of ``x`` is second-to-last
            instead of axis 1.

    Returns:
        The table(s) viewed with size 1 on every axis of ``x`` except the
        sequence and last axes. A tuple input yields a tuple output.

    Raises:
        AssertionError: If the table shape does not match ``x``.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim

    is_pair = isinstance(freqs_cis, tuple)
    # For a (cos, sin) pair only the first entry is shape-checked.
    reference = freqs_cis[0] if is_pair else freqs_cis
    seq_axis = ndim - 2 if head_first else 1

    assert reference.shape == (
        x.shape[seq_axis],
        x.shape[-1],
    ), f"freqs_cis shape {reference.shape} does not match x shape {x.shape}"

    kept_axes = (seq_axis, ndim - 1)
    shape = [size if axis in kept_axes else 1 for axis, size in enumerate(x.shape)]

    if is_pair:
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    return freqs_cis.view(*shape)
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
    """Rotate every adjacent channel pair ``(a, b)`` of the last axis to ``(-b, a)``.

    The input is cast to float32, and the result is flattened from axis 3
    onward, so ``x`` is expected to be 4-D with an even last dimension.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    even, odd = pairs.unbind(-1)
    rotated = torch.stack([-odd, even], dim=-1)
    return rotated.flatten(3)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis,
    head_first: bool = False,
):
    """Apply rotary position embedding to query and key tensors.

    Args:
        xq: Query tensor (4-D; the last axis holds interleaved channel pairs).
        xk: Key tensor with the same layout as ``xq``.
        freqs_cis: ``(cos, sin)`` tuple of real tables, or one complex table.
        head_first: Forwarded to ``reshape_for_broadcast``.

    Returns:
        ``(xq_out, xk_out)`` with the dtypes of ``xq`` and ``xk``. The math is
        done in float32. The tables are shaped against ``xq`` only and then
        reused for ``xk``.
    """
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        cos = cos.to(xq.device)
        sin = sin.to(xq.device)
        # out = x * cos + rotate_half(x) * sin, i.e. per channel pair:
        #   real * cos - imag * sin
        #   imag * cos + real * sin
        rotated = [
            (tensor.float() * cos + rotate_half(tensor.float()) * sin).type_as(tensor)
            for tensor in (xq, xk)
        ]
        return rotated[0], rotated[1]

    def _as_complex(tensor):
        # Pack [..., D/2, 2] real pairs into [..., D/2] complex values.
        return torch.view_as_complex(tensor.float().reshape(*tensor.shape[:-1], -1, 2))

    xq_complex = _as_complex(xq)  # [B, S, H, D//2]
    table = reshape_for_broadcast(freqs_cis, xq_complex, head_first).to(xq.device)
    # Complex multiplication performs the rotation; view_as_real unpacks the pairs again.
    xq_out = torch.view_as_real(xq_complex * table).flatten(3).type_as(xq)
    xk_complex = _as_complex(xk)  # [B, S, H, D//2]
    xk_out = torch.view_as_real(xk_complex * table).flatten(3).type_as(xk)
    return xq_out, xk_out
|
|
||||||
|
|
||||||
|
|
||||||
def attention(q, k, v):
    """Scaled dot-product attention for ``[B, L, H, D]`` inputs.

    Axes 1 and 2 are swapped before the kernel call and swapped back after,
    then the last two axes are merged, giving ``[B, L, H * D]``.
    """
    q_t, k_t, v_t = (tensor.transpose(1, 2) for tensor in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q_t, k_t, v_t)
    return out.transpose(1, 2).flatten(2, 3)
|
|
||||||
|
|
||||||
|
|
||||||
class MMDoubleStreamBlockComponent(torch.nn.Module):
    """One modality's half of a double-stream block.

    ``forward`` produces this stream's modulated, normalised q/k/v plus the
    modulation terms needed later. ``process_ff`` consumes the attention
    output, which the owning block computes, and applies the gated attention
    and feed-forward residual updates.
    """

    def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
        super().__init__()
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        ff_width = hidden_size * mlp_width_ratio

        self.mod = ModulateDiT(hidden_size)
        self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
        self.norm_q = RMSNorm(dim=head_dim, eps=1e-6)
        self.norm_k = RMSNorm(dim=head_dim, eps=1e-6)
        self.to_out = torch.nn.Linear(hidden_size, hidden_size)

        self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, ff_width),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(ff_width, hidden_size),
        )

    def forward(self, hidden_states, conditioning, freqs_cis=None):
        """Compute q/k/v for this stream.

        Args:
            hidden_states: Token stream ``[B, L, hidden_size]``.
            conditioning: Vector fed to the modulation layer.
            freqs_cis: Optional rotary tables; applied to q and k when given.

        Returns:
            ``((q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate))``
            where q/k/v are ``[B, L, H, D]``.
        """
        shift1, scale1, gate1, shift2, scale2, gate2 = self.mod(conditioning).chunk(6, dim=-1)

        modulated = modulate(self.norm1(hidden_states), shift=shift1, scale=scale1)
        q, k, v = rearrange(
            self.to_qkv(modulated),
            "B L (K H D) -> K B L H D",
            K=3,
            H=self.heads_num,
        )

        q, k = self.norm_q(q), self.norm_k(k)
        if freqs_cis is not None:
            q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)

        return (q, k, v), (gate1, shift2, scale2, gate2)

    def process_ff(self, hidden_states, attn_output, mod):
        """Apply the gated attention projection and gated feed-forward residuals.

        ``mod`` is the second element returned by ``forward``.
        """
        attn_gate, ff_shift, ff_scale, ff_gate = mod
        hidden_states = hidden_states + self.to_out(attn_output) * attn_gate.unsqueeze(1)
        ff_input = modulate(self.norm2(hidden_states), shift=ff_shift, scale=ff_scale)
        hidden_states = hidden_states + self.ff(ff_input) * ff_gate.unsqueeze(1)
        return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class MMDoubleStreamBlock(torch.nn.Module):
    """Double-stream block: two token streams with separate weights and shared attention windows.

    Stream ``a`` receives the rotary embedding; stream ``b`` does not.
    """

    def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
        super().__init__()
        self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
        self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)

    def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis):
        """Run both streams through windowed attention and their own feed-forward.

        Attention runs in two independent windows: stream ``a`` joined with
        the first 71 tokens of stream ``b``, and the remaining tokens of
        stream ``b`` on their own.
        NOTE(review): 71 is a hard-coded token count whose origin is not
        visible in this file -- confirm with the caller.

        Returns:
            ``(hidden_states_a, hidden_states_b)`` after the block.
        """
        qkv_a, mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis)
        qkv_b, mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None)

        # Move the first 71 tokens of stream b into stream a's attention window.
        joint_qkv, tail_qkv = [], []
        for tensor_a, tensor_b in zip(qkv_a, qkv_b):
            joint_qkv.append(torch.concat([tensor_a, tensor_b[:, :71]], dim=1))
            tail_qkv.append(tensor_b[:, 71:].contiguous())

        joint_out = attention(*joint_qkv)
        tail_out = attention(*tail_qkv)

        # Undo the regrouping so each stream gets back its own tokens.
        attn_output_a = joint_out[:, :-71].contiguous()
        attn_output_b = torch.concat([joint_out[:, -71:], tail_out], dim=1)

        hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a)
        hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
        return hidden_states_a, hidden_states_b
|
|
||||||
|
|
||||||
|
|
||||||
class MMSingleStreamBlockOriginal(torch.nn.Module):
    """Single-stream block with fused projections.

    ``linear1`` produces q/k/v and the MLP hidden activations in one matmul;
    ``linear2`` consumes the concatenated attention output and activated MLP
    hidden state.
    """

    def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
        super().__init__()
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        self.mlp_hidden_dim = hidden_size * mlp_width_ratio
        head_dim = hidden_size // heads_num

        self.linear1 = torch.nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.q_norm = RMSNorm(dim=head_dim, eps=1e-6)
        self.k_norm = RMSNorm(dim=head_dim, eps=1e-6)

        self.pre_norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = torch.nn.GELU(approximate="tanh")
        self.modulation = ModulateDiT(hidden_size, factor=3)

    def forward(self, x, vec, freqs_cis=None, txt_len=256):
        """Apply the block to a sequence whose last ``txt_len`` tokens are text.

        Rotary embedding is applied only to the tokens before the text tail.
        Attention is evaluated in two independent windows split 185 tokens
        from the end.
        NOTE(review): 185 is a hard-coded constant; it equals ``256 - 71``,
        but its relation to ``txt_len`` is not established here -- confirm.
        """
        shift, scale, gate = self.modulation(vec).chunk(3, dim=-1)
        fused = self.linear1(modulate(self.pre_norm(x), shift=shift, scale=scale))
        qkv, mlp = torch.split(fused, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Rotate only the leading (non-text) tokens, then stitch the text tail back on.
        lead_q, lead_k = apply_rotary_emb(
            q[:, :-txt_len, :, :], k[:, :-txt_len, :, :], freqs_cis, head_first=False
        )
        q = torch.cat((lead_q, q[:, -txt_len:, :, :]), dim=1)
        k = torch.cat((lead_k, k[:, -txt_len:, :, :]), dim=1)

        window_outputs = [
            attention(q[:, window].contiguous(), k[:, window].contiguous(), v[:, window].contiguous())
            for window in (slice(None, -185), slice(-185, None))
        ]
        attn_output = torch.concat(window_outputs, dim=1)

        output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
        return x + output * gate.unsqueeze(1)
|
|
||||||
|
|
||||||
|
|
||||||
class MMSingleStreamBlock(torch.nn.Module):
    """Single-stream block with separate qkv / output / feed-forward projections.

    Attention and the feed-forward both read the same modulated, normalised
    input, and both residual updates share one gate.
    """

    def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
        super().__init__()
        self.heads_num = heads_num

        self.mod = ModulateDiT(hidden_size, factor=3)
        self.norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
        self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
        self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
        self.to_out = torch.nn.Linear(hidden_size, hidden_size)

        self.ff = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
        )

    def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256):
        """Apply the block to a sequence whose last ``txt_len`` tokens are text.

        Args:
            hidden_states: Token stream ``[B, L, hidden_size]``.
            conditioning: Vector fed to the modulation layer.
            freqs_cis: Rotary tables for the tokens before the text tail.
            txt_len: Number of trailing text tokens.

        Returns:
            Tensor with the same shape as ``hidden_states``.

        Attention runs in two independent windows: the leading tokens plus
        the first 71 text tokens, and the remaining text tokens.
        NOTE(review): 71 is a hard-coded token count whose origin is not
        visible in this file -- confirm with the caller.
        """
        mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)

        norm_hidden_states = self.norm(hidden_states)
        norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale)
        qkv = self.to_qkv(norm_hidden_states)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)

        q = self.norm_q(q)
        k = self.norm_k(k)

        q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
        k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
        q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)

        q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
        k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()

        # Split V at the same boundary as Q/K. The previous literal
        # ``v[:, :-185]`` matched only when txt_len == 256 (256 - 71 = 185);
        # any other txt_len produced K and V windows of different lengths.
        first_window = q_a.shape[1]
        v_a, v_b = v[:, :first_window].contiguous(), v[:, first_window:].contiguous()

        attn_output_a = attention(q_a, k_a, v_a)
        attn_output_b = attention(q_b, k_b, v_b)
        attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)

        hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1)
        # The feed-forward reads the modulated pre-attention input, not the updated stream.
        hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1)
        return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class FinalLayer(torch.nn.Module):
    """Output head: conditioned LayerNorm followed by a linear projection to patch values."""

    def __init__(self, hidden_size=3072, patch_size=(1, 2, 2), out_channels=16):
        super().__init__()

        self.norm_final = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # One output value per (patch element, output channel).
        self.linear = torch.nn.Linear(
            hidden_size,
            patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
        )

        self.adaLN_modulation = torch.nn.Sequential(
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, 2 * hidden_size),
        )

    def forward(self, x, c):
        """Project tokens ``x`` of shape ``[B, L, hidden_size]`` under conditioning ``c`` of shape ``[B, hidden_size]``."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        normed = self.norm_final(x)
        # Same affine form as the file's modulate() helper with both terms present.
        modulated = normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return self.linear(modulated)
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoDiT(torch.nn.Module):
|
|
||||||
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40):
    """Assemble the video DiT.

    Args:
        in_channels: Channels of the input latent.
        hidden_size: Token width shared by all blocks.
        text_dim: Channel width of the incoming prompt embeddings.
        num_double_blocks: Number of double-stream blocks.
        num_single_blocks: Number of single-stream blocks.
    """
    super().__init__()
    self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
    self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
    self.time_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
    self.vector_in = torch.nn.Sequential(
        torch.nn.Linear(768, hidden_size),
        torch.nn.SiLU(),
        torch.nn.Linear(hidden_size, hidden_size)
    )
    self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
    self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
    self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
    self.final_layer = FinalLayer(hidden_size)

    # TODO: remove these parameters
    self.dtype = torch.bfloat16
    self.patch_size = [1, 2, 2]
    # Report the width the layers were actually built with; this used to be
    # the literal 3072 regardless of the constructor argument.
    self.hidden_size = hidden_size
    self.heads_num = 24
    self.rope_dim_list = [16, 56, 56]
|
|
||||||
|
|
||||||
def unpatchify(self, x, T, H, W):
    """Fold a patch-token sequence back into a 5-D tensor.

    ``x`` is ``[B, T*H*W, C*1*2*2]`` and the result is ``[B, C, T, 2*H, 2*W]``.
    ``T`` is accepted for symmetry but not passed to ``rearrange``, which
    infers it from the sequence length.
    """
    pattern = "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)"
    return rearrange(x, pattern, H=H, W=W, pT=1, pH=2, pW=2)
|
|
||||||
|
|
||||||
def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
    """Record the warm/cold device pair and move the whole module to the cold device."""
    self.warm_device, self.cold_device = warm_device, cold_device
    self.to(self.cold_device)
|
|
||||||
|
|
||||||
def load_models_to_device(self, loadmodel_names=(), device="cpu"):
    """Move the named sub-modules to ``device`` and release cached CUDA memory.

    Args:
        loadmodel_names: Attribute names on ``self`` to move. Attributes
            whose value is ``None`` are skipped. The default is an immutable
            empty tuple rather than a shared mutable list.
        device: Target device passed to each sub-module's ``.to()``.

    Raises:
        AttributeError: If a name is not an attribute of ``self``.
    """
    for model_name in loadmodel_names:
        model = getattr(self, model_name)
        if model is not None:
            model.to(device)
    torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
def prepare_freqs(self, latents):
    """Return the ``(freqs_cos, freqs_sin)`` rotary tables for ``latents``."""
    freqs = HunyuanVideoRope(latents)
    return freqs
|
|
||||||
|
|
||||||
def forward(
    self,
    x: torch.Tensor,
    t: torch.Tensor,
    prompt_emb: torch.Tensor = None,
    text_mask: torch.Tensor = None,
    pooled_prompt_emb: torch.Tensor = None,
    freqs_cos: torch.Tensor = None,
    freqs_sin: torch.Tensor = None,
    guidance: torch.Tensor = None,
    **kwargs
):
    """Run the DiT on a latent video.

    Args:
        x: Latent tensor ``[B, C, T, H, W]``.
        t: Timestep(s) for the timestep embedders.
        prompt_emb: Per-token prompt embeddings for the token refiner.
        text_mask: Validity mask for ``prompt_emb``.
        pooled_prompt_emb: Pooled prompt vector (768 channels, per ``vector_in``).
        freqs_cos: Rotary cosine table for the image tokens.
        freqs_sin: Rotary sine table for the image tokens.
        guidance: Guidance value; multiplied by 1000 before embedding.
        **kwargs: Ignored.

    Returns:
        Tensor with the same layout as ``x``, produced by ``unpatchify``.
    """
    B, C, T, H, W = x.shape

    vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
    img = self.img_in(x)
    txt = self.txt_in(prompt_emb, t, text_mask)
    # Use the real text length rather than the literal 256 when splitting
    # and stripping text tokens; identical when the prompt has 256 tokens.
    txt_len = txt.shape[1]
    freqs_cis = (freqs_cos, freqs_sin)

    for block in tqdm(self.double_blocks, desc="Double stream blocks"):
        img, txt = block(img, txt, vec, freqs_cis)

    x = torch.concat([img, txt], dim=1)
    for block in tqdm(self.single_blocks, desc="Single stream blocks"):
        x = block(x, vec, freqs_cis, txt_len=txt_len)

    # Drop the trailing text tokens before projecting back to patch values.
    img = x[:, :-txt_len]
    img = self.final_layer(img, vec)
    img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
    return img
|
|
||||||
|
|
||||||
|
|
||||||
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
|
|
||||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
|
||||||
if device is None or weight.device == device:
|
|
||||||
if not copy:
|
|
||||||
if dtype is None or weight.dtype == dtype:
|
|
||||||
return weight
|
|
||||||
return weight.to(dtype=dtype, copy=copy)
|
|
||||||
|
|
||||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
|
||||||
r.copy_(weight)
|
|
||||||
return r
|
|
||||||
|
|
||||||
def cast_weight(s, input=None, dtype=None, device=None):
|
|
||||||
if input is not None:
|
|
||||||
if dtype is None:
|
|
||||||
dtype = input.dtype
|
|
||||||
if device is None:
|
|
||||||
device = input.device
|
|
||||||
weight = cast_to(s.weight, dtype, device)
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
|
||||||
if input is not None:
|
|
||||||
if dtype is None:
|
|
||||||
dtype = input.dtype
|
|
||||||
if bias_dtype is None:
|
|
||||||
bias_dtype = dtype
|
|
||||||
if device is None:
|
|
||||||
device = input.device
|
|
||||||
weight = cast_to(s.weight, dtype, device)
|
|
||||||
bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None
|
|
||||||
return weight, bias
|
|
||||||
|
|
||||||
class quantized_layer:
    """Namespace for layer wrappers that cast their parameters at call time.

    Every wrapper stores a target ``dtype`` and ``device`` and hands them to
    the ``cast_to`` / ``cast_weight`` / ``cast_bias_weight`` helpers, which are
    defined outside this block (NOTE(review): their exact semantics are not
    visible here — presumably they return the parameters converted to the
    requested dtype/device; confirm).
    """

    class Linear(torch.nn.Linear):
        """``torch.nn.Linear`` whose weight and bias are cast on every forward."""

        def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
            # ``dtype``/``device`` are keyword-only here, so they are NOT
            # forwarded to ``torch.nn.Linear``; they only describe where the
            # computation should take place.
            super().__init__(*args, **kwargs)
            self.dtype = dtype
            self.device = device

        def block_forward_(self, x, i, j, dtype, device):
            """Compute the contribution of weight block (row ``j``, column ``i``).

            NOTE(review): relies on ``self.block_size``, which is not assigned
            anywhere in the visible code — it must be set externally before
            this method can be used.
            """
            weight_ = cast_to(
                self.weight[j * self.block_size: (j + 1) * self.block_size, i * self.block_size: (i + 1) * self.block_size],
                dtype=dtype, device=device
            )
            # The bias is added only once per output block (with the first
            # input block), otherwise it would be accumulated repeatedly.
            if self.bias is None or i > 0:
                bias_ = None
            else:
                bias_ = cast_to(self.bias[j * self.block_size: (j + 1) * self.block_size], dtype=dtype, device=device)
            x_ = x[..., i * self.block_size: (i + 1) * self.block_size]
            y_ = torch.nn.functional.linear(x_, weight_, bias_)
            del x_, weight_, bias_
            torch.cuda.empty_cache()
            return y_

        def block_forward(self, x, **kwargs):
            """Block-wise equivalent of ``forward`` that accumulates partial products."""
            # This feature can only reduce 2GB VRAM, so we disable it.
            y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
            for i in range((self.in_features + self.block_size - 1) // self.block_size):
                for j in range((self.out_features + self.block_size - 1) // self.block_size):
                    y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
            return y

        def forward(self, x, **kwargs):
            """Apply the linear map using freshly cast weight and bias."""
            weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
            return torch.nn.functional.linear(x, weight, bias)

    class RMSNorm(torch.nn.Module):
        """Wrapper around an existing RMS-norm ``module`` (reads its ``eps`` and ``weight``)."""

        def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
            super().__init__()
            self.module = module
            self.dtype = dtype
            self.device = device

        def forward(self, hidden_states, **kwargs):
            """Normalise by the root mean square over the last dimension."""
            input_dtype = hidden_states.dtype
            # Statistics are computed in float32, the result is cast back.
            variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
            hidden_states = hidden_states.to(input_dtype)
            if self.module.weight is not None:
                # Use the dtype/device this wrapper was configured with, like
                # the sibling wrappers do, instead of hard-coded bfloat16/cuda.
                weight = cast_weight(self.module, hidden_states, dtype=self.dtype, device=self.device)
                hidden_states = hidden_states * weight
            return hidden_states

    class Conv3d(torch.nn.Conv3d):
        """``torch.nn.Conv3d`` whose weight and bias are cast on every forward."""

        def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
            super().__init__(*args, **kwargs)
            self.dtype = dtype
            self.device = device

        def forward(self, x):
            """Run the convolution with the layer's stride/padding/dilation/groups."""
            weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
            return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)

    class LayerNorm(torch.nn.LayerNorm):
        """``torch.nn.LayerNorm`` whose affine parameters are cast on every forward."""

        def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
            super().__init__(*args, **kwargs)
            self.dtype = dtype
            self.device = device

        def forward(self, x):
            """Apply layer norm; parameters are cast only when both weight and bias exist."""
            if self.weight is not None and self.bias is not None:
                weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
                return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
            else:
                return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
|
||||||
|
|
||||||
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
    """Recursively replace supported child layers with ``quantized_layer`` wrappers.

    ``Linear``, ``Conv3d`` and ``LayerNorm`` children are rebuilt with the same
    configuration and receive the original parameters through
    ``load_state_dict(..., assign=True)``; ``RMSNorm`` children are wrapped.
    Any other child is descended into.

    Args:
        model: Module whose children are replaced in place.
        dtype: Computation dtype handed to every wrapper.
        device: Computation device handed to every wrapper.
    """
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear):
            # NOTE(review): ``init_weights_on_device`` is defined outside this
            # block — presumably it avoids allocating real storage for the
            # throw-away initial weights; confirm.
            with init_weights_on_device():
                new_layer = quantized_layer.Linear(
                    module.in_features, module.out_features, bias=module.bias is not None,
                    dtype=dtype, device=device
                )
            new_layer.load_state_dict(module.state_dict(), assign=True)
            setattr(model, name, new_layer)
        elif isinstance(module, torch.nn.Conv3d):
            with init_weights_on_device():
                # Forward the complete convolution configuration. The wrapper's
                # forward reads padding/dilation/groups from itself, and a
                # missing bias must be mirrored for load_state_dict to succeed.
                new_layer = quantized_layer.Conv3d(
                    module.in_channels, module.out_channels, kernel_size=module.kernel_size, stride=module.stride,
                    padding=module.padding, dilation=module.dilation, groups=module.groups,
                    bias=module.bias is not None, padding_mode=module.padding_mode,
                    dtype=dtype, device=device
                )
            new_layer.load_state_dict(module.state_dict(), assign=True)
            setattr(model, name, new_layer)
        elif isinstance(module, RMSNorm):
            # RMSNorm is wrapped rather than rebuilt; the wrapper keeps a
            # reference to the original module.
            new_layer = quantized_layer.RMSNorm(
                module,
                dtype=dtype, device=device
            )
            setattr(model, name, new_layer)
        elif isinstance(module, torch.nn.LayerNorm):
            with init_weights_on_device():
                new_layer = quantized_layer.LayerNorm(
                    module.normalized_shape, elementwise_affine=module.elementwise_affine, eps=module.eps,
                    dtype=dtype, device=device
                )
            new_layer.load_state_dict(module.state_dict(), assign=True)
            setattr(model, name, new_layer)
        else:
            replace_layer(module, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
replace_layer(self, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
@staticmethod
def state_dict_converter():
    """Return a new ``HunyuanVideoDiTStateDictConverter`` instance."""
    converter = HunyuanVideoDiTStateDictConverter()
    return converter
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoDiTStateDictConverter:
    """Renames the parameters of an external HunyuanVideo DiT checkpoint."""

    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        """Translate checkpoint keys to the names used by this implementation.

        A checkpoint nested under a top-level ``"module"`` key is unwrapped
        first. Keys belonging to none of the known groups are dropped. The
        fused ``linear1`` / ``linear2`` parameters of ``single_blocks`` are
        split into an attention part and a feed-forward part.

        Args:
            state_dict: Mapping from parameter name to tensor.

        Returns:
            A new dict with renamed (and, where needed, split) parameters.
        """
        if "module" in state_dict:
            state_dict = state_dict["module"]

        # Module paths that map one-to-one.
        top_level_names = {
            "img_in.proj": "img_in.proj",
            "time_in.mlp.0": "time_in.timestep_embedder.0",
            "time_in.mlp.2": "time_in.timestep_embedder.2",
            "vector_in.in_layer": "vector_in.0",
            "vector_in.out_layer": "vector_in.2",
            "guidance_in.mlp.0": "guidance_in.timestep_embedder.0",
            "guidance_in.mlp.2": "guidance_in.timestep_embedder.2",
            "txt_in.input_embedder": "txt_in.input_embedder",
            "txt_in.t_embedder.mlp.0": "txt_in.t_embedder.timestep_embedder.0",
            "txt_in.t_embedder.mlp.2": "txt_in.t_embedder.timestep_embedder.2",
            "txt_in.c_embedder.linear_1": "txt_in.c_embedder.0",
            "txt_in.c_embedder.linear_2": "txt_in.c_embedder.2",
            "final_layer.linear": "final_layer.linear",
            "final_layer.adaLN_modulation.1": "final_layer.adaLN_modulation.1",
        }
        # Sub-module names inside the remaining ``txt_in`` blocks.
        refiner_names = {
            "norm1": "norm1",
            "self_attn_qkv": "self_attn_qkv",
            "self_attn_proj": "self_attn_proj",
            "norm2": "norm2",
            "mlp.fc1": "mlp.0",
            "mlp.fc2": "mlp.2",
            "adaLN_modulation.1": "adaLN_modulation.1",
        }
        # Sub-module names inside ``double_blocks.<n>``.
        double_block_names = {
            "img_mod.linear": "component_a.mod.linear",
            "img_attn_qkv": "component_a.to_qkv",
            "img_attn_q_norm": "component_a.norm_q",
            "img_attn_k_norm": "component_a.norm_k",
            "img_attn_proj": "component_a.to_out",
            "img_mlp.fc1": "component_a.ff.0",
            "img_mlp.fc2": "component_a.ff.2",
            "txt_mod.linear": "component_b.mod.linear",
            "txt_attn_qkv": "component_b.to_qkv",
            "txt_attn_q_norm": "component_b.norm_q",
            "txt_attn_k_norm": "component_b.norm_k",
            "txt_attn_proj": "component_b.to_out",
            "txt_mlp.fc1": "component_b.ff.0",
            "txt_mlp.fc2": "component_b.ff.2",
        }
        # Sub-module names inside ``single_blocks.<n>``; a list marks a fused
        # parameter that is split into (attention, feed-forward) targets.
        single_block_names = {
            "linear1": ["to_qkv", "ff.0"],
            "linear2": ["to_out", "ff.2"],
            "q_norm": "norm_q",
            "k_norm": "norm_k",
            "modulation.linear": "mod.linear",
        }

        converted = {}
        for key, tensor in state_dict.items():
            parts = key.split(".")
            root, leaf = parts[0], parts[-1]
            module_path = ".".join(parts[:-1])

            if module_path in top_level_names:
                converted[top_level_names[module_path] + "." + leaf] = tensor
                continue

            if root == "double_blocks":
                block = ".".join(parts[:2])
                inner = ".".join(parts[2:-1])
                converted[block + "." + double_block_names[inner] + "." + leaf] = tensor
            elif root == "single_blocks":
                block = ".".join(parts[:2])
                inner = ".".join(parts[2:-1])
                target = single_block_names[inner]
                if not isinstance(target, list):
                    converted[block + "." + target + "." + leaf] = tensor
                elif inner == "linear1":
                    # Fused input projection: rows are [qkv | feed-forward in].
                    attn_name, ff_name = target
                    attn_part, ff_part = torch.split(tensor, (3072*3, 3072*4), dim=0)
                    converted[block + "." + attn_name + "." + leaf] = attn_part
                    converted[block + "." + ff_name + "." + leaf] = ff_part
                elif inner == "linear2":
                    attn_name, ff_name = target
                    if leaf == "weight":
                        # Fused output projection: columns are [attention | feed-forward].
                        attn_part, ff_part = torch.split(tensor, (3072*1, 3072*4), dim=-1)
                        converted[block + "." + attn_name + "." + leaf] = attn_part
                        converted[block + "." + ff_name + "." + leaf] = ff_part
                    else:
                        # Any other parameter (e.g. the bias) goes to the
                        # attention output projection only.
                        converted[block + "." + attn_name + "." + leaf] = tensor
            elif root == "txt_in":
                block = ".".join(parts[:4]).replace(".individual_token_refiner.", ".")
                inner = ".".join(parts[4:-1])
                converted[block + "." + refiner_names[inner] + "." + leaf] = tensor
        return converted
|
|
||||||
@@ -1,55 +0,0 @@
|
|||||||
from transformers import LlamaModel, LlamaConfig, DynamicCache
|
|
||||||
from copy import deepcopy
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoLLMEncoder(LlamaModel):
    """``LlamaModel`` variant that returns an intermediate hidden state.

    With auto-offload enabled, every sub-module is deep-copied and the copy is
    moved to the device of the data right before it is used.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.auto_offload = False

    def enable_auto_offload(self, **kwargs):
        """Switch on per-module copying to the input device (``kwargs`` are ignored)."""
        self.auto_offload = True

    def forward(
        self,
        input_ids,
        attention_mask,
        hidden_state_skip_layer=2
    ):
        """Run the decoder stack and stop ``hidden_state_skip_layer`` layers early.

        Args:
            input_ids: Token ids fed to the embedding layer.
            attention_mask: Mask forwarded to ``_update_causal_mask``.
            hidden_state_skip_layer: The loop stops after the first layer whose
                index satisfies ``index + hidden_state_skip_layer + 1 >= len(self.layers)``.

        Returns:
            The hidden states produced by the last executed decoder layer.
        """
        def on_device(module, device):
            # A temporary copy is used so the stored module stays where it is.
            return deepcopy(module).to(device) if self.auto_offload else module

        inputs_embeds = on_device(self.embed_tokens, input_ids.device)(input_ids)

        kv_cache = DynamicCache()

        cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
        position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, False)
        hidden_states = inputs_embeds

        # Position embeddings are computed once and shared by all decoder layers.
        position_embeddings = on_device(self.rotary_emb, input_ids.device)(hidden_states, position_ids)

        # Index of the last decoder layer that is still executed.
        final_layer_id = len(self.layers) - hidden_state_skip_layer - 1
        for layer_id, decoder_layer in enumerate(self.layers):
            layer = on_device(decoder_layer, hidden_states.device)
            hidden_states = layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=kv_cache,
                output_attentions=False,
                use_cache=True,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )[0]
            if layer_id >= final_layer_id:
                break

        return hidden_states
|
|
||||||
@@ -1,507 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from einops import rearrange
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
|
||||||
from einops import repeat
|
|
||||||
|
|
||||||
|
|
||||||
class CausalConv3d(nn.Module):
    """3D convolution preceded by explicit padding.

    Spatial axes are padded by ``kernel_size // 2`` on both sides, while the
    whole temporal padding (``kernel_size - 1``) is placed before the first
    frame and none after the last one.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride=1, dilation=1, pad_mode='replicate', **kwargs):
        super().__init__()
        self.pad_mode = pad_mode
        spatial_pad = kernel_size // 2
        temporal_pad = kernel_size - 1
        # F.pad expects the last dimension first: (W, W, H, H, T_front, T_back).
        self.time_causal_padding = (spatial_pad,) * 4 + (temporal_pad, 0)
        self.conv = nn.Conv3d(in_channel, out_channel, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        """Pad ``x`` and apply the convolution."""
        padded = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(padded)
|
|
||||||
|
|
||||||
|
|
||||||
class UpsampleCausal3D(nn.Module):
    """Nearest-neighbour upsampling that never stretches the first frame in time.

    The first frame is upsampled spatially only; all later frames are
    upsampled with the full ``upsample_factor``. An optional ``CausalConv3d``
    is applied afterwards.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, kernel_size=None, bias=True, upsample_factor=(2, 2, 2)):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.upsample_factor = upsample_factor
        if use_conv:
            conv_kernel = 3 if kernel_size is None else kernel_size
            self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=conv_kernel, bias=bias)
        else:
            self.conv = None

    def forward(self, hidden_states):
        """Upsample a ``(B, C, T, H, W)`` tensor and optionally convolve it."""
        # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16.
        source_dtype = hidden_states.dtype
        is_bf16 = source_dtype == torch.bfloat16
        if is_bf16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        _, _, num_frames, _, _ = hidden_states.shape
        head, tail = hidden_states.split((1, num_frames - 1), dim=2)
        # The first frame is treated as an image: spatial factors only.
        head = F.interpolate(head.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest").unsqueeze(2)
        if num_frames > 1:
            tail = F.interpolate(tail, scale_factor=self.upsample_factor, mode="nearest")
            hidden_states = torch.cat((head, tail), dim=2)
        else:
            hidden_states = head

        # Restore the caller's dtype if it was bfloat16.
        if is_bf16:
            hidden_states = hidden_states.to(source_dtype)

        if self.conv is not None:
            hidden_states = self.conv(hidden_states)

        return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class ResnetBlockCausal3D(nn.Module):
    """Pre-norm residual block built from two ``CausalConv3d`` layers.

    Each convolution is preceded by GroupNorm and SiLU. When the channel count
    changes, the skip path goes through a 1x1x1 ``CausalConv3d``.
    """

    def __init__(self, in_channels, out_channels=None, dropout=0.0, groups=32, eps=1e-6, conv_shortcut_bias=True):
        super().__init__()
        self.pre_norm = True
        self.in_channels = in_channels
        if out_channels is None:
            out_channels = in_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)

        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1)

        self.dropout = nn.Dropout(dropout)
        self.nonlinearity = nn.SiLU()

        # A projection on the skip path is only needed when shapes differ.
        if in_channels != out_channels:
            self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, bias=conv_shortcut_bias)
        else:
            self.conv_shortcut = None

    def forward(self, input_tensor):
        """Return ``shortcut(input_tensor) + residual_branch(input_tensor)``."""
        # First norm -> activation -> conv stage.
        branch = self.conv1(self.nonlinearity(self.norm1(input_tensor)))
        # Second stage, with dropout right before the convolution.
        branch = self.conv2(self.dropout(self.nonlinearity(self.norm2(branch))))

        skip = input_tensor
        if self.conv_shortcut is not None:
            skip = self.conv_shortcut(skip)
        return skip + branch
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_causal_attention_mask(n_frame, n_hw, dtype, device, batch_size=None):
    """Build an additive frame-causal attention mask.

    Tokens are laid out frame by frame with ``n_hw`` tokens per frame. A query
    token may attend to every token of its own frame and of all earlier
    frames (mask value ``0``); tokens of later frames are blocked (``-inf``).

    Args:
        n_frame: Number of frames.
        n_hw: Number of tokens per frame.
        dtype: Floating point dtype of the returned mask.
        device: Device of the returned mask.
        batch_size: If given, the mask is expanded (without copying) to
            ``(batch_size, seq_len, seq_len)``.

    Returns:
        Tensor of shape ``(seq_len, seq_len)`` or ``(batch_size, seq_len, seq_len)``
        with ``seq_len = n_frame * n_hw``.
    """
    seq_len = n_frame * n_hw
    # Frame index of every token, e.g. [0, 0, 1, 1, ...] for n_hw == 2.
    frame_of_token = torch.arange(n_frame, device=device).repeat_interleave(n_hw)
    # visible[i, j] is True when key j lies in the same or an earlier frame
    # than query i. One comparison replaces a Python loop over all rows.
    visible = frame_of_token.unsqueeze(0) <= frame_of_token.unsqueeze(1)
    mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
    mask.masked_fill_(visible, 0)
    if batch_size is not None:
        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
    return mask
|
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
    """Multi-head self-attention over a ``(batch, sequence, channels)`` tensor.

    The input is group-normalised over the channel axis, projected to
    queries/keys/values, attended with
    ``F.scaled_dot_product_attention`` and projected back to ``in_channels``.
    """

    def __init__(self,
                 in_channels,
                 num_heads,
                 head_dim,
                 num_groups=32,
                 dropout=0.0,
                 eps=1e-6,
                 bias=True,
                 residual_connection=True):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.residual_connection = residual_connection
        dim_inner = head_dim * num_heads
        self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
        self.to_q = nn.Linear(in_channels, dim_inner, bias=bias)
        self.to_k = nn.Linear(in_channels, dim_inner, bias=bias)
        self.to_v = nn.Linear(in_channels, dim_inner, bias=bias)
        self.to_out = nn.Sequential(nn.Linear(dim_inner, in_channels, bias=bias), nn.Dropout(dropout))

    def forward(self, input_tensor, attn_mask=None):
        """Attend over the sequence axis.

        Args:
            input_tensor: Tensor of shape ``(batch, sequence, in_channels)``.
            attn_mask: Optional mask; it is reshaped with
                ``view(batch, num_heads, -1, attn_mask.shape[-1])`` before use.

        Returns:
            Tensor shaped like ``input_tensor``; the input is added to the
            attention output when ``residual_connection`` is enabled.
        """
        # GroupNorm expects channels on axis 1, hence the transposes.
        hidden_states = self.group_norm(input_tensor.transpose(1, 2)).transpose(1, 2)
        batch_size = hidden_states.shape[0]

        q = self.to_q(hidden_states)
        k = self.to_k(hidden_states)
        v = self.to_v(hidden_states)

        # (batch, sequence, heads * dim) -> (batch, heads, sequence, dim)
        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if attn_mask is not None:
            attn_mask = attn_mask.view(batch_size, self.num_heads, -1, attn_mask.shape[-1])
        hidden_states = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = self.to_out(hidden_states)

        # Without a residual connection the plain attention output is
        # returned (previously this path hit an unassigned variable).
        if self.residual_connection:
            return input_tensor + hidden_states
        return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class UNetMidBlockCausal3D(nn.Module):
    """Middle block: one resnet followed by ``num_layers`` (attention, resnet) pairs.

    All sub-blocks keep the channel count at ``in_channels``. Attention runs
    over the flattened ``(frame, height, width)`` axes with a frame-causal mask.
    """

    def __init__(self, in_channels, dropout=0.0, num_layers=1, eps=1e-6, num_groups=32, attention_head_dim=None):
        super().__init__()

        def make_resnet():
            return ResnetBlockCausal3D(
                in_channels=in_channels,
                out_channels=in_channels,
                dropout=dropout,
                groups=num_groups,
                eps=eps,
            )

        head_dim = attention_head_dim or in_channels
        resnet_blocks = [make_resnet()]
        attention_blocks = []
        for _ in range(num_layers):
            attention_blocks.append(
                Attention(
                    in_channels,
                    num_heads=in_channels // head_dim,
                    head_dim=head_dim,
                    num_groups=num_groups,
                    dropout=dropout,
                    eps=eps,
                    bias=True,
                    residual_connection=True,
                )
            )
            resnet_blocks.append(make_resnet())

        self.attentions = nn.ModuleList(attention_blocks)
        self.resnets = nn.ModuleList(resnet_blocks)

    def forward(self, hidden_states):
        """Process a ``(B, C, T, H, W)`` tensor; the shape is preserved."""
        hidden_states = self.resnets[0](hidden_states)
        for attention, resnet in zip(self.attentions, self.resnets[1:]):
            batch, _, frames, height, width = hidden_states.shape
            # Flatten space-time into one token axis for attention.
            tokens = rearrange(hidden_states, "b c f h w -> b (f h w) c")
            causal_mask = prepare_causal_attention_mask(
                frames, height * width, tokens.dtype, tokens.device, batch_size=batch
            )
            tokens = attention(tokens, attn_mask=causal_mask)
            hidden_states = rearrange(tokens, "b (f h w) c -> b c f h w", f=frames, h=height, w=width)
            hidden_states = resnet(hidden_states)

        return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class UpDecoderBlockCausal3D(nn.Module):
    """Decoder stage: ``num_layers`` resnets followed by an optional upsampler."""

    def __init__(
        self,
        in_channels,
        out_channels,
        dropout=0.0,
        num_layers=1,
        eps=1e-6,
        num_groups=32,
        add_upsample=True,
        upsample_scale_factor=(2, 2, 2),
    ):
        super().__init__()
        # Only the first resnet changes the channel count.
        self.resnets = nn.ModuleList([
            ResnetBlockCausal3D(
                in_channels=in_channels if layer_idx == 0 else out_channels,
                out_channels=out_channels,
                groups=num_groups,
                dropout=dropout,
                eps=eps,
            )
            for layer_idx in range(num_layers)
        ])

        if add_upsample:
            upsampler = UpsampleCausal3D(
                out_channels,
                use_conv=True,
                out_channels=out_channels,
                upsample_factor=upsample_scale_factor,
            )
            self.upsamplers = nn.ModuleList([upsampler])
        else:
            self.upsamplers = None

    def forward(self, hidden_states):
        """Apply all resnets, then every upsampler if present."""
        for block in self.resnets:
            hidden_states = block(hidden_states)
        if self.upsamplers is None:
            return hidden_states
        for upsample in self.upsamplers:
            hidden_states = upsample(hidden_states)
        return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class DecoderCausal3D(nn.Module):
    """Causal 3D decoder: input conv, middle block, up blocks and output conv.

    Up blocks are built from ``block_out_channels`` in reverse order. Spatial
    upsampling is added to the first ``log2(spatial_compression_ratio)``
    blocks; temporal upsampling to the last ``log2(time_compression_ratio)``
    blocks before the final one.
    """

    def __init__(
        self,
        in_channels=16,
        out_channels=3,
        eps=1e-6,
        dropout=0.0,
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        num_groups=32,
        time_compression_ratio=4,
        spatial_compression_ratio=8,
        gradient_checkpointing=False,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        widest = block_out_channels[-1]
        self.conv_in = CausalConv3d(in_channels, widest, kernel_size=3, stride=1)
        self.up_blocks = nn.ModuleList([])

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=widest,
            dropout=dropout,
            eps=eps,
            num_groups=num_groups,
            attention_head_dim=widest,
        )

        # up
        num_blocks = len(block_out_channels)
        channels_top_down = list(reversed(block_out_channels))
        block_in = channels_top_down[0]
        for idx, block_out in enumerate(channels_top_down):
            is_last = idx == num_blocks - 1
            spatial_levels = int(np.log2(spatial_compression_ratio))
            temporal_levels = int(np.log2(time_compression_ratio))

            grow_space = bool(idx < spatial_levels)
            grow_time = bool(idx >= num_blocks - 1 - temporal_levels and not is_last)

            # Scale factor is ordered (T, H, W).
            scale_t = (2,) if grow_time else (1,)
            scale_hw = (2, 2) if grow_space else (1, 1)

            self.up_blocks.append(
                UpDecoderBlockCausal3D(
                    in_channels=block_in,
                    out_channels=block_out,
                    dropout=dropout,
                    num_layers=layers_per_block + 1,
                    eps=eps,
                    num_groups=num_groups,
                    add_upsample=bool(grow_space or grow_time),
                    upsample_scale_factor=tuple(scale_t + scale_hw),
                )
            )
            block_in = block_out

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups, eps=eps)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, hidden_states):
        """Decode latents; checkpointing is used only in training mode when enabled."""
        hidden_states = self.conv_in(hidden_states)

        # Middle block first, then every up block in order.
        stages = [self.mid_block, *self.up_blocks]
        use_checkpoint = self.training and self.gradient_checkpointing
        for stage in stages:
            if use_checkpoint:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    stage,
                    hidden_states,
                    use_reentrant=False,
                )
            else:
                hidden_states = stage(hidden_states)

        # post-process
        hidden_states = self.conv_out(self.conv_act(self.conv_norm_out(hidden_states)))
        return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoVAEDecoder(nn.Module):
    """VAE decoder wrapper with tiled decoding for large latent videos.

    ``forward`` rescales the latents, applies a 1x1x1 convolution and runs the
    ``DecoderCausal3D``. ``tile_forward`` decodes overlapping latent tiles and
    blends them with linear ramps in the overlap regions.
    """

    def __init__(
        self,
        in_channels=16,
        out_channels=3,
        eps=1e-6,
        dropout=0.0,
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        num_groups=32,
        time_compression_ratio=4,
        spatial_compression_ratio=8,
        gradient_checkpointing=False,
    ):
        super().__init__()
        self.decoder = DecoderCausal3D(
            in_channels=in_channels,
            out_channels=out_channels,
            eps=eps,
            dropout=dropout,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            num_groups=num_groups,
            time_compression_ratio=time_compression_ratio,
            spatial_compression_ratio=spatial_compression_ratio,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.post_quant_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
        self.scaling_factor = 0.476986

    def forward(self, latents):
        """Decode one latent tensor (divided by ``scaling_factor`` first)."""
        latents = latents / self.scaling_factor
        latents = self.post_quant_conv(latents)
        dec = self.decoder(latents)
        return dec

    def build_1d_mask(self, length, left_bound, right_bound, border_width):
        """Return a 1D blending weight vector of ``length`` ones with ramps.

        A side that is not a boundary of the full volume gets a linear ramp of
        ``border_width`` entries (``1/border_width`` ... ``1``). A
        ``border_width`` of zero (tiles that do not overlap) leaves the vector
        at ones; previously ``x[-0:]`` addressed the whole vector and the
        assignment failed.
        """
        x = torch.ones((length,))
        if border_width > 0:
            ramp = (torch.arange(border_width) + 1) / border_width
            if not left_bound:
                x[:border_width] = ramp
            if not right_bound:
                x[-border_width:] = torch.flip(ramp, dims=(0,))
        return x

    def build_mask(self, data, is_bound, border_width):
        """Combine per-axis ramps into a ``(1, 1, T, H, W)`` blending mask.

        Args:
            data: Tensor whose last three axes give the mask size.
            is_bound: Six flags ``(t_left, t_right, h_left, h_right, w_left, w_right)``.
            border_width: Ramp widths for the ``(T, H, W)`` axes.
        """
        _, _, T, H, W = data.shape
        t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
        h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
        w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])

        t = repeat(t, "T -> T H W", T=T, H=H, W=W)
        h = repeat(h, "H -> T H W", T=T, H=H, W=W)
        w = repeat(w, "W -> T H W", T=T, H=H, W=W)

        # The weight of a voxel is the smallest of its three axis weights.
        mask = torch.stack([t, h, w]).min(dim=0).values
        mask = rearrange(mask, "T H W -> 1 1 T H W")
        return mask

    def tile_forward(self, hidden_states, tile_size, tile_stride):
        """Decode ``hidden_states`` tile by tile and blend the results.

        NOTE(review): the output size is hard-coded as 3 channels with factors
        4 (time) and 8 (space) — presumably matching the default compression
        ratios; confirm before using other configurations.

        Args:
            hidden_states: Latents of shape ``(B, C, T, H, W)``.
            tile_size: Tile extent ``(t, h, w)`` in latent units.
            tile_stride: Step between tile origins ``(t, h, w)`` in latent units.

        Returns:
            Blended decoded video on the device of ``hidden_states``.
        """
        B, C, T, H, W = hidden_states.shape
        size_t, size_h, size_w = tile_size
        stride_t, stride_h, stride_w = tile_stride

        # Split tasks. A tile is skipped when the previous tile along that
        # axis already reached the end of the volume.
        tasks = []
        for t in range(0, T, stride_t):
            if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
            for h in range(0, H, stride_h):
                if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
                for w in range(0, W, stride_w):
                    if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
                    t_, h_, w_ = t + size_t, h + size_h, w + size_w
                    tasks.append((t, t_, h, h_, w, w_))

        # Run
        torch_dtype = self.post_quant_conv.weight.dtype
        data_device = hidden_states.device
        computation_device = self.post_quant_conv.weight.device

        # ``weight`` accumulates the blending masks, ``values`` the weighted pixels.
        weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
        values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)

        for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
            hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
            hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
            # For tiles that do not start at the first frame, the first
            # decoded frame is dropped and the tile is placed at t * 4 + 1.
            if t > 0:
                hidden_states_batch = hidden_states_batch[:, :, 1:]

            mask = self.build_mask(
                hidden_states_batch,
                is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
                border_width=((size_t - stride_t) * 4, (size_h - stride_h) * 8, (size_w - stride_w) * 8)
            ).to(dtype=torch_dtype, device=data_device)

            target_t = 0 if t==0 else t * 4 + 1
            target_h = h * 8
            target_w = w * 8
            values[
                :,
                :,
                target_t: target_t + hidden_states_batch.shape[2],
                target_h: target_h + hidden_states_batch.shape[3],
                target_w: target_w + hidden_states_batch.shape[4],
            ] += hidden_states_batch * mask
            weight[
                :,
                :,
                target_t: target_t + hidden_states_batch.shape[2],
                target_h: target_h + hidden_states_batch.shape[3],
                target_w: target_w + hidden_states_batch.shape[4],
            ] += mask
        return values / weight

    def decode_video(self, latents, tile_size=(17, 32, 32), tile_stride=(12, 24, 24)):
        """Cast ``latents`` to the model dtype and decode them with tiling."""
        latents = latents.to(self.post_quant_conv.weight.dtype)
        return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)

    @staticmethod
    def state_dict_converter():
        """Return a new ``HunyuanVideoVAEDecoderStateDictConverter`` instance."""
        return HunyuanVideoVAEDecoderStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoVAEDecoderStateDictConverter:
    """Selects the decoder-related entries of a VAE checkpoint."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Keep only keys starting with ``decoder.`` or ``post_quant_conv.``.

        Names and values are passed through unchanged, in their original order.
        """
        kept_prefixes = ('decoder.', 'post_quant_conv.')
        return {
            key: state_dict[key]
            for key in state_dict
            if key.startswith(kept_prefixes)
        }
|
|
||||||
@@ -1,307 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
|
||||||
from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D
|
|
||||||
|
|
||||||
|
|
||||||
class DownsampleCausal3D(nn.Module):
    """Downsampling layer implemented as a single strided ``CausalConv3d``."""

    def __init__(self, channels, out_channels, kernel_size=3, bias=True, stride=2):
        super().__init__()
        self.conv = CausalConv3d(channels, out_channels, kernel_size, stride=stride, bias=bias)

    def forward(self, hidden_states):
        """Apply the strided convolution."""
        return self.conv(hidden_states)
|
|
||||||
|
|
||||||
|
|
||||||
class DownEncoderBlockCausal3D(nn.Module):
    """Encoder stage: ``num_layers`` resnets followed by an optional downsampler."""

    def __init__(
        self,
        in_channels,
        out_channels,
        dropout=0.0,
        num_layers=1,
        eps=1e-6,
        num_groups=32,
        add_downsample=True,
        downsample_stride=2,
    ):
        super().__init__()
        # Only the first resnet changes the channel count.
        self.resnets = nn.ModuleList([
            ResnetBlockCausal3D(
                in_channels=in_channels if layer_idx == 0 else out_channels,
                out_channels=out_channels,
                groups=num_groups,
                dropout=dropout,
                eps=eps,
            )
            for layer_idx in range(num_layers)
        ])

        if add_downsample:
            downsampler = DownsampleCausal3D(
                out_channels,
                out_channels,
                stride=downsample_stride,
            )
            self.downsamplers = nn.ModuleList([downsampler])
        else:
            self.downsamplers = None

    def forward(self, hidden_states):
        """Apply all resnets, then every downsampler if present."""
        for block in self.resnets:
            hidden_states = block(hidden_states)

        if self.downsamplers is None:
            return hidden_states
        for downsample in self.downsamplers:
            hidden_states = downsample(hidden_states)
        return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class EncoderCausal3D(nn.Module):
    """Causal 3D encoder: ``conv_in`` -> down blocks -> mid block -> norm / SiLU / ``conv_out``.

    ``conv_out`` produces ``2 * out_channels`` channels.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 16,
        eps=1e-6,
        dropout=0.0,
        block_out_channels=[128, 256, 512, 512],  # never mutated here, so the shared default is safe
        layers_per_block=2,
        num_groups=32,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
        gradient_checkpointing=False,
    ):
        """Build the encoder layers.

        Args:
            in_channels: Channel count of the input fed to ``conv_in``.
            out_channels: Half the channel count produced by ``conv_out``.
            eps: Epsilon forwarded to the blocks and the output GroupNorm.
            dropout: Dropout forwarded to the down blocks and mid block.
            block_out_channels: Output channels of each down block, in order.
            layers_per_block: Number of resnet layers per down block.
            num_groups: Group count for the blocks and the output GroupNorm.
            time_compression_ratio: ``log2`` of this gives how many down
                blocks use a temporal stride of 2.
            spatial_compression_ratio: ``log2`` of this gives how many down
                blocks use a spatial stride of 2.
            gradient_checkpointing: When true, ``forward`` checkpoints the
                down and mid blocks while the module is in training mode.
        """
        super().__init__()
        self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
        self.down_blocks = nn.ModuleList([])

        # down
        # These counts do not depend on the block index; compute them once.
        num_blocks = len(block_out_channels)
        num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
        num_time_downsample_layers = int(np.log2(time_compression_ratio))

        output_channel = block_out_channels[0]
        for i in range(num_blocks):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == num_blocks - 1

            # Spatial downsampling happens in the first blocks, temporal
            # downsampling in the last non-final blocks.
            add_spatial_downsample = bool(i < num_spatial_downsample_layers)
            add_time_downsample = bool(i >= (num_blocks - 1 - num_time_downsample_layers) and not is_final_block)

            downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
            downsample_stride_T = (2,) if add_time_downsample else (1,)
            downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
            down_block = DownEncoderBlockCausal3D(
                in_channels=input_channel,
                out_channels=output_channel,
                dropout=dropout,
                num_layers=layers_per_block,
                eps=eps,
                num_groups=num_groups,
                add_downsample=bool(add_spatial_downsample or add_time_downsample),
                downsample_stride=downsample_stride,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            dropout=dropout,
            eps=eps,
            num_groups=num_groups,
            attention_head_dim=block_out_channels[-1],
        )
        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups, eps=eps)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3)

        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, hidden_states):
        """Encode ``hidden_states`` and return the output of ``conv_out``.

        With ``self.training`` and ``self.gradient_checkpointing`` both true,
        the down blocks and the mid block run under
        ``torch.utils.checkpoint.checkpoint``; otherwise they are called
        directly. Both paths compute the same chain of modules.
        """
        hidden_states = self.conv_in(hidden_states)
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # down
            for down_block in self.down_blocks:
                # The checkpointed output must be fed to the next block;
                # previously it was discarded, so the down blocks were
                # silently skipped whenever checkpointing was enabled.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block),
                    hidden_states,
                    use_reentrant=False,
                )
            # middle
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block),
                hidden_states,
                use_reentrant=False,
            )
        else:
            # down
            for down_block in self.down_blocks:
                hidden_states = down_block(hidden_states)
            # middle
            hidden_states = self.mid_block(hidden_states)
        # post-process
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoVAEEncoder(nn.Module):
    """VAE encoder wrapper: ``EncoderCausal3D`` + 1x1x1 ``quant_conv`` + tiled inference."""

    def __init__(
        self,
        in_channels=3,
        out_channels=16,
        eps=1e-6,
        dropout=0.0,
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        num_groups=32,
        time_compression_ratio=4,
        spatial_compression_ratio=8,
        gradient_checkpointing=False,
    ):
        """Build the encoder; all arguments are forwarded to ``EncoderCausal3D``."""
        super().__init__()
        self.encoder = EncoderCausal3D(
            in_channels=in_channels,
            out_channels=out_channels,
            eps=eps,
            dropout=dropout,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            num_groups=num_groups,
            time_compression_ratio=time_compression_ratio,
            spatial_compression_ratio=spatial_compression_ratio,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
        # Latents returned by ``forward`` are multiplied by this constant.
        self.scaling_factor = 0.476986

    def forward(self, images):
        """Encode ``images`` and return the scaled first 16 latent channels.

        NOTE(review): the hard-coded 16 matches the default ``out_channels``;
        a different ``out_channels`` would presumably need a different slice.
        """
        latents = self.encoder(images)
        latents = self.quant_conv(latents)
        latents = latents[:, :16]
        latents = latents * self.scaling_factor
        return latents

    def build_1d_mask(self, length, left_bound, right_bound, border_width):
        """Return a 1-D blending weight vector of size ``length``.

        The vector is all ones except for a linear ramp over the first
        ``border_width`` entries (unless ``left_bound``) and the mirrored ramp
        over the last ``border_width`` entries (unless ``right_bound``).
        A non-positive ``border_width`` means no ramp at all.
        """
        x = torch.ones((length,))
        # ``x[-0:]`` selects the whole vector, so an unguarded zero width
        # would try to assign an empty ramp to it and raise.
        if border_width <= 0:
            return x
        if not left_bound:
            x[:border_width] = (torch.arange(border_width) + 1) / border_width
        if not right_bound:
            x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
        return x

    def build_mask(self, data, is_bound, border_width):
        """Return a ``(1, 1, T, H, W)`` blending mask for the 5-D tensor ``data``.

        Args:
            data: Tensor whose last three dimensions give ``T``, ``H``, ``W``.
            is_bound: Six flags ``(t_left, t_right, h_left, h_right, w_left,
                w_right)``; a true flag disables the ramp on that side.
            border_width: Ramp widths ``(t, h, w)``.

        The mask is the element-wise minimum of the three per-axis ramps.
        """
        _, _, T, H, W = data.shape
        t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
        h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
        w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])

        t = repeat(t, "T -> T H W", T=T, H=H, W=W)
        h = repeat(h, "H -> T H W", T=T, H=H, W=W)
        w = repeat(w, "W -> T H W", T=T, H=H, W=W)

        mask = torch.stack([t, h, w]).min(dim=0).values
        mask = rearrange(mask, "T H W -> 1 1 T H W")
        return mask

    def tile_forward(self, hidden_states, tile_size, tile_stride):
        """Encode ``hidden_states`` in overlapping tiles and blend the results.

        Args:
            hidden_states: 5-D tensor ``(B, C, T, H, W)``.
            tile_size: Tile extent ``(t, h, w)`` in input units.
            tile_stride: Step between tile origins ``(t, h, w)`` in input units.

        Returns:
            Tensor of shape ``(B, 16, (T - 1) // 4 + 1, H // 8, W // 8)`` on
            the device of ``hidden_states``: the mask-weighted average of all
            tile encodings. The factors 4, 8 and 16 are hard-coded here.
        """
        B, C, T, H, W = hidden_states.shape
        size_t, size_h, size_w = tile_size
        stride_t, stride_h, stride_w = tile_stride

        # Split tasks
        tasks = []
        for t in range(0, T, stride_t):
            # Skip a tile when the previous one already reached the end of the axis.
            if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
            for h in range(0, H, stride_h):
                if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
                for w in range(0, W, stride_w):
                    if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
                    t_, h_, w_ = t + size_t, h + size_h, w + size_w
                    tasks.append((t, t_, h, h_, w, w_))

        # Run
        torch_dtype = self.quant_conv.weight.dtype
        data_device = hidden_states.device
        computation_device = self.quant_conv.weight.device

        # ``values`` accumulates mask-weighted encodings, ``weight`` the masks.
        weight = torch.zeros((1, 1, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
        values = torch.zeros((B, 16, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)

        for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
            hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
            hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
            if t > 0:
                # Tiles that do not start at t == 0 drop their first latent frame.
                hidden_states_batch = hidden_states_batch[:, :, 1:]

            mask = self.build_mask(
                hidden_states_batch,
                is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
                border_width=((size_t - stride_t) // 4, (size_h - stride_h) // 8, (size_w - stride_w) // 8)
            ).to(dtype=torch_dtype, device=data_device)

            target_t = 0 if t==0 else t // 4 + 1
            target_h = h // 8
            target_w = w // 8
            values[
                :,
                :,
                target_t: target_t + hidden_states_batch.shape[2],
                target_h: target_h + hidden_states_batch.shape[3],
                target_w: target_w + hidden_states_batch.shape[4],
            ] += hidden_states_batch * mask
            weight[
                :,
                :,
                target_t: target_t + hidden_states_batch.shape[2],
                target_h: target_h + hidden_states_batch.shape[3],
                target_w: target_w + hidden_states_batch.shape[4],
            ] += mask
        return values / weight

    def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=(48, 192, 192)):
        """Cast the input to the ``quant_conv`` weight dtype and run ``tile_forward``."""
        latents = latents.to(self.quant_conv.weight.dtype)
        return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)

    @staticmethod
    def state_dict_converter():
        """Return a new ``HunyuanVideoVAEEncoderStateDictConverter`` instance."""
        return HunyuanVideoVAEEncoderStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoVAEEncoderStateDictConverter:
    """Filters a state dict down to the entries used by the VAE encoder."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Return the entries whose key starts with 'encoder.' or 'quant_conv.'.

        Keys are kept unchanged and values are passed through as-is; every
        other entry is dropped.
        """
        kept_prefixes = ('encoder.', 'quant_conv.')
        return {
            name: value
            for name, value in state_dict.items()
            if name.startswith(kept_prefixes)
        }
|
|
||||||
File diff suppressed because one or more lines are too long
@@ -1,367 +0,0 @@
|
|||||||
import torch
|
|
||||||
from .sd_unet import SDUNet
|
|
||||||
from .sdxl_unet import SDXLUNet
|
|
||||||
from .sd_text_encoder import SDTextEncoder
|
|
||||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
|
||||||
from .sd3_dit import SD3DiT
|
|
||||||
from .flux_dit import FluxDiT
|
|
||||||
from .hunyuan_dit import HunyuanDiT
|
|
||||||
from .cog_dit import CogDiT
|
|
||||||
from .hunyuan_video_dit import HunyuanVideoDiT
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAFromCivitai:
    """Base LoRA loader: merges up/down (or B/A) pairs into weight deltas and adds them to a model.

    Subclasses fill in ``supported_model_classes`` and ``lora_prefix``
    (parallel lists), ``renamed_lora_prefix`` and ``special_keys``.
    """

    def __init__(self):
        self.supported_model_classes = []
        self.lora_prefix = []
        self.renamed_lora_prefix = {}
        self.special_keys = {}

    @staticmethod
    def _resolve_device(device):
        """Return ``device``, or ``"cpu"`` when CUDA is requested but unavailable."""
        if torch.device(device).type == "cuda" and not torch.cuda.is_available():
            return "cpu"
        return device

    @staticmethod
    def _matmul(weight_up, weight_down):
        """Return ``weight_up @ weight_down``.

        Half-precision inputs on CPU are multiplied in float32 and cast back,
        because half matmul is not available on CPU in every PyTorch build.
        """
        if weight_up.device.type == "cpu" and weight_up.dtype in (torch.float16, torch.bfloat16):
            return torch.mm(weight_up.float(), weight_down.float()).to(weight_up.dtype)
        return torch.mm(weight_up, weight_down)

    def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
        """Merge a LoRA state dict into ``{target_name: delta_weight}``.

        Dispatches to the ``lora_up``/``lora_down`` converter when any key
        contains ".lora_up", otherwise to the ``lora_B``/``lora_A`` converter.
        """
        if any(".lora_up" in key for key in state_dict):
            return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
        return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)

    def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
        """Merge ``lora_up``/``lora_down`` pairs whose key starts with ``lora_prefix``.

        The target name is the part of the key before the first ".", with the
        prefix replaced via ``renamed_lora_prefix``, "_" turned into ".",
        ".weight" appended, and ``special_keys`` replacements applied in
        order. Results are returned on CPU.
        """
        renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
        # Use the GPU when there is one; fall back to CPU instead of failing.
        device = self._resolve_device("cuda")
        state_dict_ = {}
        for key in state_dict:
            if ".lora_up" not in key:
                continue
            if not key.startswith(lora_prefix):
                continue
            weight_up = state_dict[key].to(device=device, dtype=torch.float16)
            weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device=device, dtype=torch.float16)
            if len(weight_up.shape) == 4:
                # 4-D weights are squeezed to matrices, merged in float32,
                # then restored to 4-D.
                weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
                weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
                lora_weight = alpha * self._matmul(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * self._matmul(weight_up, weight_down)
            target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
            # Order matters: later replacements see the result of earlier ones.
            for special_key in self.special_keys:
                target_name = target_name.replace(special_key, self.special_keys[special_key])
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_

    def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
        """Merge ``lora_B``/``lora_A`` pairs whose key starts with ``lora_prefix``.

        The target name is the key with the "lora_B" component removed and the
        leading ``lora_prefix`` stripped. Results are returned on CPU. A CUDA
        ``device`` falls back to CPU when CUDA is unavailable.
        """
        device = self._resolve_device(device)
        state_dict_ = {}
        for key in state_dict:
            if ".lora_B." not in key:
                continue
            if not key.startswith(lora_prefix):
                continue
            weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
            weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze(3).squeeze(2)
                weight_down = weight_down.squeeze(3).squeeze(2)
                lora_weight = alpha * self._matmul(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * self._matmul(weight_up, weight_down)
            keys = key.split(".")
            keys.pop(keys.index("lora_B"))
            target_name = ".".join(keys)
            target_name = target_name[len(lora_prefix):]
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_

    def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
        """Add the merged LoRA deltas to ``model``'s weights.

        Args:
            model: Module exposing ``state_dict``/``load_state_dict`` and a
                class-level ``state_dict_converter()``.
            state_dict_lora: Raw LoRA state dict.
            lora_prefix: Key prefix selecting the tensors to merge.
            alpha: Scale applied to every delta.
            model_resource: "diffusers" or "civitai" to run the matching
                converter on the merged names; anything else skips it.
        """
        state_dict_model = model.state_dict()
        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
        if model_resource == "diffusers":
            state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
        elif model_resource == "civitai":
            state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
        # Some converters return (state_dict, extra_kwargs).
        if isinstance(state_dict_lora, tuple):
            state_dict_lora = state_dict_lora[0]
        if len(state_dict_lora) > 0:
            print(f" {len(state_dict_lora)} tensors are updated.")
            for name in state_dict_lora:
                fp8=False
                if state_dict_model[name].dtype == torch.float8_e4m3fn:
                    # Leave fp8 for the addition, then cast back afterwards.
                    state_dict_model[name]= state_dict_model[name].to(state_dict_lora[name].dtype)
                    fp8=True
                state_dict_model[name] += state_dict_lora[name].to(
                    dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
                if fp8:
                    state_dict_model[name] = state_dict_model[name].to(torch.float8_e4m3fn)
            model.load_state_dict(state_dict_model)

    def match(self, model, state_dict_lora):
        """Return ``(lora_prefix, model_resource)`` for the first fitting combination, else ``None``.

        A combination fits when ``model`` is an instance of the paired class
        and every merged and converted LoRA name exists in the model's state
        dict. A combination whose conversion raises is treated as not fitting.
        """
        for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
            if not isinstance(model, model_class):
                continue
            state_dict_model = model.state_dict()
            for model_resource in ["diffusers", "civitai"]:
                try:
                    state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
                    converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
                        else model.__class__.state_dict_converter().from_civitai
                    state_dict_lora_ = converter_fn(state_dict_lora_)
                    if isinstance(state_dict_lora_, tuple):
                        state_dict_lora_ = state_dict_lora_[0]
                    if len(state_dict_lora_) == 0:
                        continue
                    for name in state_dict_lora_:
                        if name not in state_dict_model:
                            break
                    else:
                        return lora_prefix, model_resource
                except Exception:
                    # Best-effort probing: a failed conversion only means this
                    # combination does not match. KeyboardInterrupt and
                    # SystemExit are deliberately not caught.
                    pass
        return None
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SDLoRAFromCivitai(LoRAFromCivitai):
    """LoRA loader configuration for ``SDUNet`` and ``SDTextEncoder``."""

    def __init__(self):
        super().__init__()
        # Each prefix is paired with the model class at the same position.
        prefix_and_class = (
            ("lora_unet_", SDUNet),
            ("lora_te_", SDTextEncoder),
        )
        self.supported_model_classes = [model_class for _, model_class in prefix_and_class]
        self.lora_prefix = [prefix for prefix, _ in prefix_and_class]

        # Replacements are applied in insertion order, so the order below
        # must be kept.
        special_keys = {}
        special_keys.update({
            "down.blocks": "down_blocks",
            "up.blocks": "up_blocks",
            "mid.block": "mid_block",
            "proj.in": "proj_in",
            "proj.out": "proj_out",
            "transformer.blocks": "transformer_blocks",
            "to.q": "to_q",
            "to.k": "to_k",
            "to.v": "to_v",
            "to.out": "to_out",
        })
        special_keys.update({
            "text.model": "text_model",
            "self.attn.q.proj": "self_attn.q_proj",
            "self.attn.k.proj": "self_attn.k_proj",
            "self.attn.v.proj": "self_attn.v_proj",
            "self.attn.out.proj": "self_attn.out_proj",
        })
        special_keys.update({
            "input.blocks": "model.diffusion_model.input_blocks",
            "middle.block": "model.diffusion_model.middle_block",
            "output.blocks": "model.diffusion_model.output_blocks",
        })
        self.special_keys = special_keys
|
|
||||||
|
|
||||||
|
|
||||||
class SDXLLoRAFromCivitai(LoRAFromCivitai):
    """LoRA loader configuration for ``SDXLUNet`` and the two SDXL text encoders."""

    def __init__(self):
        super().__init__()
        # Each prefix is paired with the model class at the same position.
        prefix_and_class = (
            ("lora_unet_", SDXLUNet),
            ("lora_te1_", SDXLTextEncoder),
            ("lora_te2_", SDXLTextEncoder2),
        )
        self.supported_model_classes = [model_class for _, model_class in prefix_and_class]
        self.lora_prefix = [prefix for prefix, _ in prefix_and_class]
        self.renamed_lora_prefix = {"lora_te2_": "2"}

        # Replacements are applied in insertion order; the final entry
        # rewrites text produced by the earlier "text.model" replacement, so
        # the order below must be kept.
        special_keys = {}
        special_keys.update({
            "down.blocks": "down_blocks",
            "up.blocks": "up_blocks",
            "mid.block": "mid_block",
            "proj.in": "proj_in",
            "proj.out": "proj_out",
            "transformer.blocks": "transformer_blocks",
            "to.q": "to_q",
            "to.k": "to_k",
            "to.v": "to_v",
            "to.out": "to_out",
        })
        special_keys.update({
            "text.model": "conditioner.embedders.0.transformer.text_model",
            "self.attn.q.proj": "self_attn.q_proj",
            "self.attn.k.proj": "self_attn.k_proj",
            "self.attn.v.proj": "self_attn.v_proj",
            "self.attn.out.proj": "self_attn.out_proj",
        })
        special_keys.update({
            "input.blocks": "model.diffusion_model.input_blocks",
            "middle.block": "model.diffusion_model.middle_block",
            "output.blocks": "model.diffusion_model.output_blocks",
        })
        special_keys.update({
            "2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
        })
        self.special_keys = special_keys
|
|
||||||
|
|
||||||
|
|
||||||
class FluxLoRAFromCivitai(LoRAFromCivitai):
    """LoRA loader configuration for ``FluxDiT`` with two accepted key prefixes."""

    def __init__(self):
        super().__init__()
        # The same model class is listed once per accepted prefix.
        accepted_prefixes = ["lora_unet_", "transformer."]
        self.supported_model_classes = [FluxDiT for _ in accepted_prefixes]
        self.lora_prefix = accepted_prefixes
        self.renamed_lora_prefix = {}

        # Dotted names that must be turned back into underscore names.
        special_keys = {
            "single.blocks": "single_blocks",
            "double.blocks": "double_blocks",
        }
        for stream in ("img", "txt"):
            for part in ("attn", "mlp", "mod"):
                special_keys[f"{stream}.{part}"] = f"{stream}_{part}"
        self.special_keys = special_keys
|
|
||||||
|
|
||||||
|
|
||||||
class GeneralLoRAFromPeft:
    """LoRA loader for ``lora_A``/``lora_B`` state dicts whose names mirror the model's own."""

    def __init__(self):
        self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT]

    def fetch_device_dtype_from_state_dict(self, state_dict):
        """Return ``(device, dtype)`` of the first tensor, or ``(None, None)`` if empty."""
        first_param = next(iter(state_dict.values()), None)
        if first_param is None:
            return None, None
        return first_param.device, first_param.dtype

    def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict=None):
        """Merge ``lora_B``/``lora_A`` pairs into ``{target_name: delta_weight}``.

        Args:
            state_dict: LoRA state dict.
            alpha: Scale applied to every delta.
            target_state_dict: State dict of the model to patch. Its first
                tensor decides the device and dtype used for merging.
                ``None`` is treated as an empty dict.

        Returns:
            Merged deltas on CPU. An empty dict is returned as soon as one
            target name is missing from ``target_state_dict``, so the result
            is all-or-nothing.
        """
        # A shared ``{}`` default would be one object reused across calls.
        if target_state_dict is None:
            target_state_dict = {}
        device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
        state_dict_ = {}
        for key in state_dict:
            if ".lora_B." not in key:
                continue
            weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
            weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
            if len(weight_up.shape) == 4:
                # 4-D weights are merged as matrices and restored to 4-D.
                weight_up = weight_up.squeeze(3).squeeze(2)
                weight_down = weight_down.squeeze(3).squeeze(2)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            keys = key.split(".")
            lora_b_index = keys.index("lora_B")
            if len(keys) > lora_b_index + 2:
                # Drop the component right after "lora_B" (e.g. an adapter name).
                keys.pop(lora_b_index + 1)
            keys.pop(lora_b_index)
            target_name = ".".join(keys)
            if target_name not in target_state_dict:
                return {}
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_

    def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
        """Add the merged LoRA deltas to ``model``'s weights.

        ``lora_prefix`` and ``model_resource`` are accepted but not used by
        this method.
        """
        state_dict_model = model.state_dict()
        state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
        if len(state_dict_lora) > 0:
            print(f" {len(state_dict_lora)} tensors are updated.")
            for name in state_dict_lora:
                state_dict_model[name] += state_dict_lora[name].to(
                    dtype=state_dict_model[name].dtype,
                    device=state_dict_model[name].device
                )
            model.load_state_dict(state_dict_model)

    def match(self, model, state_dict_lora):
        """Return ``("", "")`` when the LoRA fits ``model``, else ``None``.

        The LoRA fits when ``model`` is an instance of a supported class and
        the conversion yields at least one tensor. A conversion that raises
        is treated as not fitting.
        """
        for model_class in self.supported_model_classes:
            if not isinstance(model, model_class):
                continue
            state_dict_model = model.state_dict()
            try:
                state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
                if len(state_dict_lora_) > 0:
                    return "", ""
            except Exception:
                # Best-effort probing; KeyboardInterrupt and SystemExit are
                # deliberately not caught.
                pass
        return None
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
    """LoRA loader configuration for ``HunyuanVideoDiT`` with two accepted key prefixes."""

    def __init__(self):
        super().__init__()
        # The same model class is listed once per accepted prefix.
        accepted_prefixes = ["diffusion_model.", "transformer."]
        self.supported_model_classes = [HunyuanVideoDiT for _ in accepted_prefixes]
        self.lora_prefix = accepted_prefixes
        self.special_keys = {}
|
|
||||||
|
|
||||||
|
|
||||||
class FluxLoRAConverter:
    """Renames FLUX LoRA keys between the ``blocks.N...lora_A/B`` and ``lora_unet_..._N_...`` layouts."""

    def __init__(self):
        pass

    @staticmethod
    def align_to_opensource_format(state_dict, alpha=1.0):
        """Rename ``blocks.N.<layer>.lora_A/B[.<adapter>].weight`` keys to the ``lora_unet_*`` layout.

        Keys whose layer part is not in the rename table are dropped. For
        every ``lora_up.weight`` entry an extra ``<module>.alpha`` entry
        holding ``alpha`` as a 0-d tensor is added.
        """
        prefix_rename_dict = {
            "single_blocks": "lora_unet_single_blocks",
            "blocks": "lora_unet_double_blocks",
        }
        middle_rename_dict = {
            "norm.linear": "modulation_lin",
            "to_qkv_mlp": "linear1",
            "proj_out": "linear2",

            "norm1_a.linear": "img_mod_lin",
            "norm1_b.linear": "txt_mod_lin",
            "attn.a_to_qkv": "img_attn_qkv",
            "attn.b_to_qkv": "txt_attn_qkv",
            "attn.a_to_out": "img_attn_proj",
            "attn.b_to_out": "txt_attn_proj",
            "ff_a.0": "img_mlp_0",
            "ff_a.2": "img_mlp_2",
            "ff_b.0": "txt_mlp_0",
            "ff_b.2": "txt_mlp_2",
        }
        suffix_rename_dict = {
            "lora_B.weight": "lora_up.weight",
            "lora_A.weight": "lora_down.weight",
        }
        converted = {}
        for name, param in state_dict.items():
            parts = name.split(".")
            # Drop the component between "lora_A"/"lora_B" and the last one.
            if parts[-2] not in ("lora_A", "lora_B"):
                parts.pop(-2)
            prefix, block_id = parts[0], parts[1]
            middle = ".".join(parts[2:-2])
            suffix = ".".join(parts[-2:])
            if middle not in middle_rename_dict:
                continue
            rename = (
                f"{prefix_rename_dict[prefix]}_{block_id}_"
                f"{middle_rename_dict[middle]}.{suffix_rename_dict[suffix]}"
            )
            converted[rename] = param
            if rename.endswith("lora_up.weight"):
                converted[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
        return converted

    @staticmethod
    def align_to_diffsynth_format(state_dict):
        """Rename ``lora_unet_*`` keys to ``blocks.N.<layer>.lora_A/B.default.weight``.

        Keys that are not in the rename table are copied through unchanged.
        """
        # (open-source layer name, DiffSynth layer name)
        double_block_layers = (
            ("img_mod_lin", "norm1_a.linear"),
            ("txt_mod_lin", "norm1_b.linear"),
            ("img_attn_qkv", "attn.a_to_qkv"),
            ("txt_attn_qkv", "attn.b_to_qkv"),
            ("img_attn_proj", "attn.a_to_out"),
            ("txt_attn_proj", "attn.b_to_out"),
            ("img_mlp_0", "ff_a.0"),
            ("img_mlp_2", "ff_a.2"),
            ("txt_mlp_0", "ff_b.0"),
            ("txt_mlp_2", "ff_b.2"),
        )
        single_block_layers = (
            ("modulation_lin", "norm.linear"),
            ("linear1", "to_qkv_mlp"),
            ("linear2", "proj_out"),
        )
        lora_sides = (("lora_down", "lora_A"), ("lora_up", "lora_B"))
        block_groups = (
            ("lora_unet_double_blocks", "blocks", double_block_layers),
            ("lora_unet_single_blocks", "single_blocks", single_block_layers),
        )

        # Keys use the literal placeholder "blockid" in place of the block index.
        rename_dict = {}
        for source_group, target_group, layers in block_groups:
            for source_layer, target_layer in layers:
                for source_side, target_side in lora_sides:
                    source_key = f"{source_group}_blockid_{source_layer}.{source_side}.weight"
                    target_key = f"{target_group}.blockid.{target_layer}.{target_side}.default.weight"
                    rename_dict[source_key] = target_key

        def guess_block_id(name):
            """Return (first all-digit "_"-separated token, name with it replaced by "blockid")."""
            block_id = next((token for token in name.split("_") if token.isdigit()), None)
            if block_id is None:
                return None, None
            return block_id, name.replace(f"_{block_id}_", "_blockid_")

        converted = {}
        for name, param in state_dict.items():
            block_id, source_name = guess_block_id(name)
            target_template = rename_dict.get(source_name)
            if target_template is None:
                converted[name] = param
            else:
                converted[target_template.replace(".blockid.", f".{block_id}.")] = param
        return converted
|
|
||||||
|
|
||||||
|
|
||||||
def get_lora_loaders():
    """Return one fresh instance of every LoRA loader, in lookup order."""
    loader_classes = (
        SDLoRAFromCivitai,
        SDXLLoRAFromCivitai,
        FluxLoRAFromCivitai,
        HunyuanVideoLoRAFromCivitai,
        GeneralLoRAFromPeft,
    )
    return [loader_class() for loader_class in loader_classes]
|
|
||||||
@@ -1,441 +0,0 @@
|
|||||||
import os, torch, json, importlib
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
|
|
||||||
|
|
||||||
from .sd_text_encoder import SDTextEncoder
|
|
||||||
from .sd_unet import SDUNet
|
|
||||||
from .sd_vae_encoder import SDVAEEncoder
|
|
||||||
from .sd_vae_decoder import SDVAEDecoder
|
|
||||||
from .lora import get_lora_loaders
|
|
||||||
|
|
||||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
|
||||||
from .sdxl_unet import SDXLUNet
|
|
||||||
from .sdxl_vae_decoder import SDXLVAEDecoder
|
|
||||||
from .sdxl_vae_encoder import SDXLVAEEncoder
|
|
||||||
|
|
||||||
from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
|
||||||
from .sd3_dit import SD3DiT
|
|
||||||
from .sd3_vae_decoder import SD3VAEDecoder
|
|
||||||
from .sd3_vae_encoder import SD3VAEEncoder
|
|
||||||
|
|
||||||
from .sd_controlnet import SDControlNet
|
|
||||||
from .sdxl_controlnet import SDXLControlNetUnion
|
|
||||||
|
|
||||||
from .sd_motion import SDMotionModel
|
|
||||||
from .sdxl_motion import SDXLMotionModel
|
|
||||||
|
|
||||||
from .svd_image_encoder import SVDImageEncoder
|
|
||||||
from .svd_unet import SVDUNet
|
|
||||||
from .svd_vae_decoder import SVDVAEDecoder
|
|
||||||
from .svd_vae_encoder import SVDVAEEncoder
|
|
||||||
|
|
||||||
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
|
||||||
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
|
||||||
|
|
||||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
|
||||||
from .hunyuan_dit import HunyuanDiT
|
|
||||||
from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
|
||||||
from .hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
|
|
||||||
|
|
||||||
from .flux_dit import FluxDiT
|
|
||||||
from .flux_text_encoder import FluxTextEncoder2
|
|
||||||
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
|
||||||
from .flux_ipadapter import FluxIpAdapter
|
|
||||||
|
|
||||||
from .cog_vae import CogVAEEncoder, CogVAEDecoder
|
|
||||||
from .cog_dit import CogDiT
|
|
||||||
|
|
||||||
from ..extensions.RIFE import IFNet
|
|
||||||
from ..extensions.ESRGAN import RRDBNet
|
|
||||||
|
|
||||||
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
|
|
||||||
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
    """Build one model per (name, class) pair from a single state dict.

    For every class, its `state_dict_converter()` translates `state_dict`
    from the layout named by `model_resource`; the converted weights are then
    assigned to a freshly constructed model.

    Args:
        state_dict: Raw weights as read from the checkpoint file.
        model_names: Names to report for the loaded models.
        model_classes: Classes to instantiate, paired with `model_names`.
        model_resource: Source layout of the weights, "civitai" or "diffusers".
        torch_dtype: Default dtype the models are cast to.
        device: Device the models are moved to.

    Returns:
        A tuple `(loaded_model_names, loaded_models)` of equally long lists.

    Raises:
        ValueError: If `model_resource` is not a supported layout.
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        print(f" model_name: {model_name} model_class: {model_class.__name__}")
        state_dict_converter = model_class.state_dict_converter()
        if model_resource == "civitai":
            state_dict_results = state_dict_converter.from_civitai(state_dict)
        elif model_resource == "diffusers":
            state_dict_results = state_dict_converter.from_diffusers(state_dict)
        else:
            # Without this branch the name below would be unbound on the first
            # iteration, or silently reuse the previous model's weights.
            raise ValueError(
                f"Unsupported model_resource {model_resource!r}; expected 'civitai' or 'diffusers'."
            )
        # A converter may return either the weights alone or (weights, constructor kwargs).
        if isinstance(state_dict_results, tuple):
            model_state_dict, extra_kwargs = state_dict_results
            print(f" This model is initialized with extra kwargs: {extra_kwargs}")
        else:
            model_state_dict, extra_kwargs = state_dict_results, {}
        # Keep the upcast decision local to this model so that it does not
        # change the dtype of the models loaded after it.
        model_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
        with init_weights_on_device():
            model = model_class(**extra_kwargs)
        model.load_state_dict(model_state_dict, assign=True)
        model = model.to(dtype=model_dtype, device=device)
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
    """Load models from a folder through each class's `from_pretrained`.

    Args:
        file_path: Folder handed to `from_pretrained`.
        model_names: Names to report for the loaded models.
        model_classes: Classes providing `from_pretrained`, paired with `model_names`.
        torch_dtype: Dtype forwarded to `from_pretrained`.
        device: Device the models are moved to (best effort).

    Returns:
        A tuple `(loaded_model_names, loaded_models)` of equally long lists.
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
        if torch_dtype == torch.float16 and hasattr(model, "half"):
            model = model.half()
        try:
            model = model.to(device=device)
        except Exception as error:
            # Moving the model stays best effort, but the failure is reported
            # and interrupts (KeyboardInterrupt/SystemExit) are not swallowed.
            print(f" Warning: could not move {model_name} to {device}: {error}")
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
    """Create a patched model from a base model plus a patch state dict.

    A new `model_class(**extra_kwargs)` first receives the base model's
    weights and then the patch weights, both loaded non-strictly, and is
    finally cast/moved to `torch_dtype` and `device`.

    Returns:
        The newly constructed, patched model.
    """
    print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
    # Take the weights before the base model is sent to the CPU.
    inherited_weights = base_model.state_dict()
    base_model.to("cpu")
    del base_model
    patched = model_class(**extra_kwargs)
    # Base weights go in first so that the patch overrides overlapping keys.
    for weights in (inherited_weights, state_dict):
        patched.load_state_dict(weights, strict=False)
    patched.to(dtype=torch_dtype, device=device)
    return patched
|
|
||||||
|
|
||||||
|
|
||||||
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
    """Patch every model in `model_manager` whose name is in `model_names`.

    Each matching base model is removed from the manager's three parallel
    lists and replaced, in the returned lists, by its patched counterpart.

    Returns:
        A tuple `(loaded_model_names, loaded_models)` of equally long lists.
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        while True:
            # Search again from the start after every removal, because popping
            # shifts the positions of the remaining entries.
            target = next(
                (
                    index
                    for index in range(len(model_manager.model))
                    if model_manager.model_name[index] == model_name
                ),
                None,
            )
            if target is None:
                break
            base_model_name = model_manager.model_name[target]
            base_model_path = model_manager.model_path[target]
            base_model = model_manager.model[target]
            print(f" Adding patch model to {base_model_name} ({base_model_path})")
            patched_model = load_single_patch_model_from_single_file(
                state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
            loaded_model_names.append(base_model_name)
            loaded_models.append(patched_model)
            model_manager.model.pop(target)
            model_manager.model_path.pop(target)
            model_manager.model_name.pop(target)
    return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDetectorTemplate:
    """Skeleton of a model detector: `match` says whether a source is recognised, `load` loads it.

    This template recognises nothing and loads nothing.
    """

    def __init__(self):
        """Nothing to configure."""

    def match(self, file_path="", state_dict={}):
        """Always report that the given source is not recognised."""
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Return an empty `(model_names, models)` pair."""
        return [], []
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDetectorFromSingleFile:
    """Detect and load models stored in one state-dict file.

    Models are recognised by a hash of the state dict's keys, first including
    parameter shapes (strict) and then without shapes (loose).
    """

    def __init__(self, model_loader_configs=[]):
        """Register every `(keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource)` entry."""
        self.keys_hash_with_shape_dict = {}
        self.keys_hash_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
        """Register one loader entry; the shape-free hash is optional (`None` disables it)."""
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
        if keys_hash is not None:
            self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)

    def _find_config(self, state_dict):
        """Return the registered `(model_names, model_classes, model_resource)` for `state_dict`, or None."""
        # Strict matching: keys and parameter shapes.
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return self.keys_hash_with_shape_dict[keys_hash_with_shape]
        # Loose matching: keys only (the shape of parameters may be inconsistent,
        # and the state_dict_converter will modify the model architecture).
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        if keys_hash in self.keys_hash_dict:
            return self.keys_hash_dict[keys_hash]
        return None

    def match(self, file_path="", state_dict={}):
        """Return True if the file (or the given state dict) is a registered model.

        Directories never match. A missing or empty `state_dict` is read from `file_path`.
        """
        if os.path.isdir(file_path):
            return False
        if not state_dict:
            state_dict = load_state_dict(file_path)
        return self._find_config(state_dict) is not None

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load the models registered for this state dict.

        Returns:
            `(loaded_model_names, loaded_models)`; both lists are empty when
            the state dict is not recognised.
        """
        if not state_dict:
            state_dict = load_state_dict(file_path)
        config = self._find_config(state_dict)
        if config is None:
            # Nothing is registered for this state dict; report "no models"
            # instead of failing on unbound result variables.
            return [], []
        model_names, model_classes, model_resource = config
        return load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
    """Detect and load models from a file that bundles several components.

    The state dict is split with `split_state_dict_with_prefix` and every part
    is checked against the parent detector's registered hashes.
    """

    def __init__(self, model_loader_configs=[]):
        """Register the same loader entries as the parent detector."""
        super().__init__(model_loader_configs)

    def match(self, file_path="", state_dict={}):
        """Return True if at least one split component is a registered model."""
        if os.path.isdir(file_path):
            return False
        if not state_dict:
            state_dict = load_state_dict(file_path)
        for sub_state_dict in split_state_dict_with_prefix(state_dict):
            if super().match(file_path, sub_state_dict):
                return True
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load the recognised components of a bundled state dict.

        The recognised components are first merged and loaded together; if the
        merged dict is not itself registered, each recognised component is
        loaded on its own.

        Returns:
            `(loaded_model_names, loaded_models)` of equally long lists.
        """
        if not state_dict:
            # Same fallback as `match`: read the file when no state dict is given.
            state_dict = load_state_dict(file_path)

        # Split the state_dict and keep only the components the parent recognises.
        recognised = []
        for sub_state_dict in split_state_dict_with_prefix(state_dict):
            if super().match(file_path, sub_state_dict):
                recognised.append(sub_state_dict)

        valid_state_dict = {}
        for sub_state_dict in recognised:
            valid_state_dict.update(sub_state_dict)

        if super().match(file_path, valid_state_dict):
            return super().load(file_path, valid_state_dict, device, torch_dtype)

        loaded_model_names, loaded_models = [], []
        for sub_state_dict in recognised:
            # Load the component itself: the merged dict was just rejected above.
            loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDetectorFromHuggingfaceFolder:
    """Detect and load models from a folder that contains a `config.json`.

    The architecture named in the config is looked up in a registry that maps
    it to the library to import, the model name to report, and an optional
    replacement class name.
    """

    def __init__(self, model_loader_configs=[]):
        """Register every `(architecture, huggingface_lib, model_name, redirected_architecture)` entry."""
        self.architecture_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
        """Register one architecture."""
        self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)

    @staticmethod
    def _declared_architectures(config):
        """Return the architecture names a config declares (`architectures`, else `_class_name`)."""
        if "architectures" in config:
            return config["architectures"]
        if "_class_name" in config:
            return [config["_class_name"]]
        return []

    def match(self, file_path="", state_dict={}):
        """Return True if `file_path` is a folder whose config declares only registered architectures.

        Requiring registration here keeps `load` from failing with a KeyError
        on folders this detector cannot actually load.
        """
        if os.path.isfile(file_path):
            return False
        file_list = os.listdir(file_path)
        if "config.json" not in file_list:
            return False
        with open(os.path.join(file_path, "config.json"), "r", encoding="utf-8") as f:
            config = json.load(f)
        architectures = self._declared_architectures(config)
        if not architectures:
            return False
        return all(architecture in self.architecture_dict for architecture in architectures)

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load one model per architecture declared in the folder's `config.json`.

        Returns:
            `(loaded_model_names, loaded_models)` of equally long lists.

        Raises:
            KeyError: If a declared architecture is not registered.
        """
        with open(os.path.join(file_path, "config.json"), "r", encoding="utf-8") as f:
            config = json.load(f)
        loaded_model_names, loaded_models = [], []
        architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
        for architecture in architectures:
            huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
            if redirected_architecture is not None:
                architecture = redirected_architecture
            model_class = getattr(importlib.import_module(huggingface_lib), architecture)
            loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDetectorFromPatchedSingleFile:
    """Detect and load patch files that modify models already held by a model manager.

    Patch files are recognised by the hash of their state-dict keys and shapes.
    """

    def __init__(self, model_loader_configs=[]):
        """Register every `(keys_hash_with_shape, model_name, model_class, extra_kwargs)` entry."""
        self.keys_hash_with_shape_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
        """Register one patch entry."""
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)

    def match(self, file_path="", state_dict={}):
        """Return True if the file (or the given state dict) is a registered patch.

        Directories never match. A missing or empty `state_dict` is read from `file_path`.
        """
        if os.path.isdir(file_path):
            return False
        if not state_dict:
            state_dict = load_state_dict(file_path)
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        return keys_hash_with_shape in self.keys_hash_with_shape_dict

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
        """Apply a registered patch to the matching models of `model_manager`.

        Returns:
            `(loaded_model_names, loaded_models)`; both lists are empty when
            the state dict is not a registered patch.
        """
        if not state_dict:
            state_dict = load_state_dict(file_path)

        # Load models with strict matching
        loaded_model_names, loaded_models = [], []
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
            loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
                state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelManager:
    """Registry of loaded models, kept as three parallel lists.

    `model[i]` was loaded from `model_path[i]` and is known as `model_name[i]`.
    """

    def __init__(
        self,
        torch_dtype=torch.float16,
        device="cuda",
        model_id_list: List[Preset_model_id] = [],
        downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
        file_path_list: List[str] = [],
    ):
        """Download the preset models (if any) and load them together with `file_path_list`.

        The list defaults are only read, never mutated.
        """
        self.torch_dtype = torch_dtype
        self.device = device
        self.model = []
        self.model_path = []
        self.model_name = []
        downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
        # Detectors are tried in this order; the first one that matches wins.
        self.model_detector = [
            ModelDetectorFromSingleFile(model_loader_configs),
            ModelDetectorFromSplitedSingleFile(model_loader_configs),
            ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
            ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
        ]
        self.load_models(downloaded_files + file_path_list)

    def _register_models(self, file_path, model_names, models):
        """Append the given models, all attributed to `file_path`, to the three parallel lists."""
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)

    def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
        """Load explicitly named models from one state-dict file and register them."""
        print(f"Loading models from file: {file_path}")
        if not state_dict:
            state_dict = load_state_dict(file_path)
        model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
        self._register_models(file_path, model_names, models)
        print(f" The following models are loaded: {model_names}.")

    def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
        """Load explicitly named models from a folder via `from_pretrained` and register them."""
        print(f"Loading models from folder: {file_path}")
        model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
        self._register_models(file_path, model_names, models)
        print(f" The following models are loaded: {model_names}.")

    def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
        """Patch already loaded models with the weights of a patch file and register the results."""
        print(f"Loading patch models from file: {file_path}")
        if not state_dict and file_path:
            # Like the other loaders, read the announced file when no state dict is supplied.
            state_dict = load_state_dict(file_path)
        model_names, models = load_patch_model_from_single_file(
            state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
        self._register_models(file_path, model_names, models)
        print(f" The following patched models are loaded: {model_names}.")

    def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
        """Apply a LoRA file (or a list of files) to every loaded model a LoRA loader matches."""
        if isinstance(file_path, list):
            for file_path_ in file_path:
                self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
        else:
            print(f"Loading LoRA models from file: {file_path}")
            if not state_dict:
                state_dict = load_state_dict(file_path)
            for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
                for lora in get_lora_loaders():
                    match_results = lora.match(model, state_dict)
                    if match_results is not None:
                        print(f" Adding LoRA to {model_name} ({model_path}).")
                        lora_prefix, model_resource = match_results
                        lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
                        # Only the first matching loader is applied to a model.
                        break

    def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
        """Detect the type of `file_path`, load it with the first matching detector and register the models.

        `device` and `torch_dtype` default to the manager's own settings.
        """
        print(f"Loading models from: {file_path}")
        if device is None: device = self.device
        if torch_dtype is None: torch_dtype = self.torch_dtype
        if os.path.isfile(file_path):
            state_dict = load_state_dict(file_path)
        else:
            state_dict = None
        for model_detector in self.model_detector:
            if model_detector.match(file_path, state_dict):
                loaded_names, loaded_models = model_detector.load(
                    file_path, state_dict,
                    device=device, torch_dtype=torch_dtype,
                    allowed_model_names=model_names, model_manager=self
                )
                self._register_models(file_path, loaded_names, loaded_models)
                print(f" The following models are loaded: {loaded_names}.")
                break
        else:
            print(f" We cannot detect the model type. No models are loaded.")

    def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
        """Call `load_model` for every path in `file_path_list`."""
        for file_path in file_path_list:
            self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)

    def fetch_model(self, model_name, file_path=None, require_model_path=False):
        """Return the first loaded model called `model_name`, optionally restricted to `file_path`.

        Returns:
            None when no model matches; otherwise the model, or
            `(model, model_path)` when `require_model_path` is true.
        """
        fetched_models = []
        fetched_model_paths = []
        for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
            if file_path is not None and file_path != model_path:
                continue
            if model_name == model_name_:
                fetched_models.append(model)
                fetched_model_paths.append(model_path)
        if len(fetched_models) == 0:
            print(f"No {model_name} models available.")
            return None
        if len(fetched_models) == 1:
            print(f"Using {model_name} from {fetched_model_paths[0]}.")
        else:
            print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
        if require_model_path:
            return fetched_models[0], fetched_model_paths[0]
        else:
            return fetched_models[0]

    def to(self, device):
        """Move every registered model to `device`."""
        for model in self.model:
            model.to(device)
|
|
||||||
|
|
||||||
@@ -1,803 +0,0 @@
|
|||||||
# The code is revised from DiT
|
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import numpy as np
|
|
||||||
import math
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
from typing import List, Optional, Tuple, Union
|
|
||||||
import torch.utils.checkpoint
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
|
||||||
from transformers import Phi3Config, Phi3Model
|
|
||||||
from transformers.cache_utils import Cache, DynamicCache
|
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Phi3Transformer(Phi3Model):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
    We only modified the attention mask; in addition, layer weights can be
    streamed between CPU and GPU during inference (`offload_model=True`).

    Args:
        config: Phi3Config
    """

    def prefetch_layer(self, layer_idx: int, device: torch.device):
        """Start copying the parameters of layer `layer_idx` to `device` on the prefetch stream."""
        with torch.cuda.stream(self.prefetch_stream):
            # Prefetch next layer tensors to GPU
            for param in self.layers[layer_idx].parameters():
                param.data = param.data.to(device, non_blocking=True)

    def evict_previous_layer(self, layer_idx: int):
        """Move the parameters of layer `layer_idx - 1` to the CPU."""
        # For layer 0 the index is -1, so the last layer is the one evicted.
        prev_layer_idx = layer_idx - 1
        for param in self.layers[prev_layer_idx].parameters():
            param.data = param.data.to("cpu", non_blocking=True)

    def get_offlaod_layer(self, layer_idx: int, device: torch.device):
        """Evict the previous layer, wait for pending copies, and prefetch the next layer.

        The misspelled method name is kept because it is the name callers use.
        """
        # init stream
        if not hasattr(self, "prefetch_stream"):
            self.prefetch_stream = torch.cuda.Stream()

        # delete previous layer
        torch.cuda.current_stream().synchronize()
        self.evict_previous_layer(layer_idx)

        # make sure the current layer is ready
        # NOTE(review): torch.cuda.synchronize is documented to take a device;
        # confirm that passing the stream here is intended.
        torch.cuda.synchronize(self.prefetch_stream)

        # load next layer (wraps around to layer 0 after the last one)
        self.prefetch_layer((layer_idx + 1) % len(self.layers), device)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        offload_model: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack on `inputs_embeds` with a caller-provided 3-D attention mask.

        Args:
            attention_mask: Required; must have exactly three dimensions. It is
                converted to an additive mask (`(1 - mask) * dtype_min`) with an
                extra dimension inserted at position 1.
            inputs_embeds: Input hidden states; their dtype and device are read here.
            offload_model: When true and the module is not training, layer
                weights are streamed in and out through `get_offlaod_layer`.
            Other arguments are resolved against `self.config` when None and
            forwarded to the decoder layers.

        Returns:
            A `BaseModelOutputWithPast`, or a tuple of its non-None fields when
            `return_dict` is false.

        Raises:
            ValueError: If both or neither of `input_ids` / `inputs_embeds` are given.
            Exception: If `attention_mask` is None or not 3-D.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        # NOTE: this override does not embed `input_ids` and does not derive
        # `cache_position` / `position_ids`; they are used exactly as passed in.

        if attention_mask is not None and attention_mask.dim() == 3:
            dtype = inputs_embeds.dtype
            min_dtype = torch.finfo(dtype).min
            attention_mask = (1 - attention_mask) * min_dtype
            attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
        else:
            raise Exception("attention_mask parameter was unavailable or invalid")

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for layer_idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                if offload_model and not self.training:
                    self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
|
||||||
|
|
||||||
|
|
||||||
def modulate(x, shift, scale):
    """Scale and shift `x`; `shift` and `scale` get a new dimension at position 1 before being applied."""
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset
|
|
||||||
|
|
||||||
|
|
||||||
class TimestepEmbedder(nn.Module):
    """
    Turns scalar timesteps into vectors: sinusoidal features followed by a two-layer MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings (cosines first, then sines).
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponents = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(exponents).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd output size: pad with one zero column.
            padding = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, padding], dim=-1)
        return embedding

    def forward(self, t, dtype=torch.float32):
        """Embed timesteps `t`; the sinusoidal features are cast to `dtype` before the MLP."""
        features = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        return self.mlp(features)
|
|
||||||
|
|
||||||
|
|
||||||
class FinalLayer(nn.Module):
    """
    Output head of the DiT: modulated layer norm followed by a linear
    projection to `patch_size * patch_size * out_channels` features.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        projected_features = patch_size * patch_size * out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, projected_features, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Normalise `x`, modulate it with the shift/scale derived from `c`, and project it."""
        # The conditioning projection is split in two halves along dim 1: shift, then scale.
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        normalised = self.norm_final(x)
        return self.linear(modulate(normalised, shift, scale))
|
|
||||||
|
|
||||||
|
|
||||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
    """
    Build 2-D sine/cosine positional embeddings for a grid.

    grid_size: an int (square grid) or a pair indexed as grid_size[0], grid_size[1].
    return: pos_embed with one row per grid cell and embed_dim columns; when
    cls_token is set and extra_tokens > 0, extra_tokens all-zero rows are put in front.
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)
    size_h, size_w = grid_size[0], grid_size[1]

    def _axis_coordinates(count):
        # Coordinates are rescaled by the grid size relative to base_size and by interpolation_scale.
        return np.arange(count, dtype=np.float32) / (count / base_size) / interpolation_scale

    coords_h = _axis_coordinates(size_h)
    coords_w = _axis_coordinates(size_w)
    # here w goes first
    grid = np.stack(np.meshgrid(coords_w, coords_h), axis=0)
    grid = grid.reshape([2, 1, size_w, size_h])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        zero_rows = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([zero_rows, pos_embed], axis=0)
    return pos_embed
|
|
||||||
|
|
||||||
|
|
||||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode grid[0] and grid[1] separately, each with half of embed_dim, and
    join the two encodings column-wise.

    embed_dim: output dimension, must be even
    out: (H*W, D)
    """
    assert embed_dim % 2 == 0

    per_axis_dim = embed_dim // 2
    # Each entry is (H*W, D/2); the encoding of grid[0] comes first.
    halves = [get_1d_sincos_pos_embed_from_grid(per_axis_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(halves, axis=1)  # (H*W, D)
|
|
||||||
|
|
||||||
|
|
||||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine/cosine encoding of a set of positions.

    embed_dim: output dimension for each position, must be even
    pos: positions to be encoded; flattened to size (M,)
    out: (M, D), sines in the first D/2 columns and cosines in the last D/2
    """
    assert embed_dim % 2 == 0
    exponents = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000 ** exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = np.outer(flat_pos, omega)  # (M, D/2)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
|
||||||
|
|
||||||
|
|
||||||
class PatchEmbedMR(nn.Module):
    """Turn a 2D feature map into a sequence of patch embeddings.

    A strided convolution (kernel size == stride == ``patch_size``) projects
    each non-overlapping patch to ``embed_dim`` channels.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_chans: int = 4,
        embed_dim: int = 768,
        bias: bool = True,
    ):
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
        )

    def forward(self, x):
        """Project ``x`` (N, C, H, W) and return tokens laid out as (N, L, C)."""
        patches = self.proj(x)
        # Collapse the spatial dims, then move channels last: NCHW -> NLC.
        return patches.flatten(2).transpose(1, 2)
|
|
||||||
|
|
||||||
|
|
||||||
class OmniGenOriginalModel(nn.Module):
|
|
||||||
"""
|
|
||||||
Diffusion model with a Transformer backbone.
|
|
||||||
"""
|
|
||||||
def __init__(
    self,
    transformer_config: Phi3Config,
    patch_size=2,
    in_channels=4,
    pe_interpolation: float = 1.0,
    pos_embed_max_size: int = 192,
):
    """Assemble the embedders, positional table, output head and LLM backbone.

    Args:
        transformer_config: Phi3 configuration; its ``hidden_size`` sets the
            width of every embedder and of the final layer.
        patch_size: Side length of the square patches.
        in_channels: Latent channel count; also used as the output channel count.
        pe_interpolation: Interpolation scale passed to the sin/cos table builder.
        pos_embed_max_size: Side length of the stored positional-embedding grid.
    """
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = in_channels
    self.patch_size = patch_size
    self.pos_embed_max_size = pos_embed_max_size

    width = transformer_config.hidden_size

    # Separate patch embedders for the latents being denoised and for
    # conditioning input images.
    self.x_embedder = PatchEmbedMR(patch_size, in_channels, width, bias=True)
    self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, width, bias=True)

    self.time_token = TimestepEmbedder(width)
    self.t_embedder = TimestepEmbedder(width)

    self.pe_interpolation = pe_interpolation
    pos_table = get_2d_sincos_pos_embed(
        width,
        pos_embed_max_size,
        interpolation_scale=self.pe_interpolation,
        base_size=64,
    )
    self.register_buffer(
        "pos_embed",
        torch.from_numpy(pos_table).float().unsqueeze(0),
        persistent=True,
    )

    self.final_layer = FinalLayer(width, patch_size, self.out_channels)

    # Custom init runs before the LLM is attached, so `self.apply` inside
    # initialize_weights does not visit the backbone's layers.
    self.initialize_weights()

    self.llm = Phi3Transformer(config=transformer_config)
    self.llm.config.use_cache = False
|
|
||||||
|
|
||||||
@classmethod
def from_pretrained(cls, model_name):
    """Create a model and load its weights from a local folder or a hub repo id.

    Args:
        model_name: Existing local directory, or a repo id that is downloaded
            with ``snapshot_download`` (cache dir taken from ``HF_HUB_CACHE``).

    Returns:
        The model with the checkpoint's state dict loaded.
    """
    if not os.path.exists(model_name):
        model_name = snapshot_download(
            repo_id=model_name,
            cache_dir=os.getenv('HF_HUB_CACHE'),
            ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'],
        )

    model = cls(Phi3Config.from_pretrained(model_name))

    safetensors_path = os.path.join(model_name, 'model.safetensors')
    if os.path.exists(safetensors_path):
        print("Loading safetensors")
        weights = load_file(safetensors_path)
    else:
        # NOTE(review): torch.load unpickles the file; only use with trusted
        # checkpoints.
        weights = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')

    model.load_state_dict(weights)
    return model
|
|
||||||
|
|
||||||
def initialize_weights(self):
    """Apply the model's custom weight initialization.

    Steps, in order:
      1. Xavier-uniform weights and zero biases for every ``nn.Linear``
         reachable through ``self.apply``.
      2. Both patch embedders (``x_embedder`` and ``input_x_embedder``):
         Xavier-uniform on the conv weight viewed as a 2D matrix, zero bias.
      3. Normal(std=0.02) weights for the two linear layers of each timestep
         embedder MLP.
      4. Zero weights and biases for the final layer's last modulation layer
         and output projection.
    """
    # NOTE(review): this guards the attribute name "llama", while the backbone
    # is stored as `self.llm` in this file; confirm which name is intended.
    assert not hasattr(self, "llama")

    # Initialize transformer layers:
    def _basic_init(module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
    self.apply(_basic_init)

    # Initialize patch embedders like nn.Linear (instead of nn.Conv2d).
    # Each embedder's *own* bias is zeroed; previously the second step
    # re-zeroed x_embedder's bias and left input_x_embedder's untouched.
    for embedder in (self.x_embedder, self.input_x_embedder):
        w = embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(embedder.proj.bias, 0)

    # Initialize timestep embedding MLPs:
    for time_embedder in (self.t_embedder, self.time_token):
        nn.init.normal_(time_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(time_embedder.mlp[2].weight, std=0.02)

    # Zero-out output layers:
    nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
    nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
    nn.init.constant_(self.final_layer.linear.weight, 0)
    nn.init.constant_(self.final_layer.linear.bias, 0)
|
|
||||||
|
|
||||||
def unpatchify(self, x, h, w):
    """Fold a token sequence back into an image-shaped tensor.

    Args:
        x: Tokens of shape (N, T, patch_size**2 * C).
        h: Output height; ``h // patch_size`` patch rows are assumed.
        w: Output width; ``w // patch_size`` patch columns are assumed.

    Returns:
        Tensor of shape (N, C, h, w).
    """
    p = self.patch_size
    channels = self.out_channels

    # (N, rows, cols, p, p, C): token grid with each patch's pixels unrolled.
    grid = x.reshape(shape=(x.shape[0], h // p, w // p, p, p, channels))
    # Reorder to (N, C, rows, p, cols, p) so rows/cols merge with their
    # in-patch offsets; same permutation as einsum 'nhwpqc->nchpwq'.
    grid = grid.permute(0, 5, 1, 3, 2, 4)
    return grid.reshape(shape=(grid.shape[0], channels, h, w))
|
|
||||||
|
|
||||||
|
|
||||||
def cropped_pos_embed(self, height, width):
    """Return the centre crop of the positional table for a given latent size.

    Args:
        height: Latent height before patching.
        width: Latent width before patching.

    Returns:
        Tensor of shape (1, rows * cols, D), where rows/cols are the latent
        size divided by ``patch_size``.

    Raises:
        ValueError: If ``pos_embed_max_size`` is None, or the patched height
            or width exceeds it.
    """
    if self.pos_embed_max_size is None:
        raise ValueError("`pos_embed_max_size` must be set for cropping.")

    # Work in patch units from here on.
    height, width = height // self.patch_size, width // self.patch_size
    if height > self.pos_embed_max_size:
        raise ValueError(
            f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
        )
    if width > self.pos_embed_max_size:
        raise ValueError(
            f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
        )

    side = self.pos_embed_max_size
    top = (side - height) // 2
    left = (side - width) // 2

    table = self.pos_embed.reshape(1, side, side, -1)
    window = table[:, top : top + height, left : left + width, :]
    return window.reshape(1, -1, window.shape[-1])
|
|
||||||
|
|
||||||
|
|
||||||
def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
    """Patch-embed latents (single tensor or list) and add positional embeddings.

    Args:
        latents: One tensor, or a list of tensors that may differ in size.
        padding_latent: For list input, an optional per-item padding tensor
            (or None) concatenated after that item's tokens on dim -2.
        is_input_images: Use ``input_x_embedder`` instead of ``x_embedder``.

    Returns:
        ``(latents, num_tokens, shapes)``.
        * Tensor input: embedded tokens, their token count, ``[height, width]``.
        * List input: per-item token counts and ``[height, width]`` pairs; the
          embedded items stay a list when ``padding_latent`` is None and are
          concatenated on dim 0 otherwise.
    """
    embedder = self.input_x_embedder if is_input_images else self.x_embedder

    def _embed(latent):
        # Embed one latent and add its cropped positional embedding.
        height, width = latent.shape[-2:]
        tokens = embedder(latent)
        pos_embed = self.cropped_pos_embed(height, width)
        return tokens + pos_embed, pos_embed, [height, width]

    if not isinstance(latents, list):
        tokens, _, shape = _embed(latents)
        return tokens, tokens.size(1), shape

    keep_as_list = padding_latent is None
    if keep_as_list:
        padding_latent = [None] * len(latents)

    patched_latents, num_tokens, shapes = [], [], []
    for latent, padding in zip(latents, padding_latent):
        tokens, pos_embed, shape = _embed(latent)
        if padding is not None:
            tokens = torch.cat([tokens, padding], dim=-2)
        patched_latents.append(tokens)
        # Token count excludes any padding that was appended.
        num_tokens.append(pos_embed.size(1))
        shapes.append(shape)

    latents = patched_latents if keep_as_list else torch.cat(patched_latents, dim=0)
    return latents, num_tokens, shapes
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
    """Embed the inputs, run the LLM backbone and decode the image tokens.

    The backbone input is ``[condition tokens, time token, image tokens]``
    (condition tokens only when ``input_ids`` is given). Spans listed in
    ``input_image_sizes`` are overwritten in the condition embeddings with
    the embedded input images, consumed in iteration order.

    Returns:
        The unpatchified latents (a list when ``x`` was a list), and also the
        backbone's ``past_key_values`` when ``return_past_key_values`` is true.
    """
    input_is_list = isinstance(x, list)
    x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
    time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)

    if input_img_latents is not None:
        input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)

    if input_ids is None:
        pieces = [time_token, x]
    else:
        condition_embeds = self.llm.embed_tokens(input_ids).clone()
        consumed = 0
        for b_inx in input_image_sizes.keys():
            for start_inx, end_inx in input_image_sizes[b_inx]:
                # Replace the placeholder span with the next embedded image.
                condition_embeds[b_inx, start_inx: end_inx] = input_latents[consumed]
                consumed += 1
        if input_img_latents is not None:
            assert consumed == len(input_latents)
        pieces = [condition_embeds, time_token, x]
    input_emb = torch.cat(pieces, dim=1)

    output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
    hidden, past_key_values = output.last_hidden_state, output.past_key_values

    # Image tokens sit at the end of the sequence.
    tail = max(num_tokens) if input_is_list else num_tokens
    image_embedding = hidden[:, -tail:]
    time_emb = self.t_embedder(timestep, dtype=x.dtype)
    x = self.final_layer(image_embedding, time_emb)

    if input_is_list:
        # Drop per-sample padding before folding tokens back into images.
        latents = [
            self.unpatchify(x[i:i+1, :num_tokens[i]], shapes[i][0], shapes[i][1])
            for i in range(x.size(0))
        ]
    else:
        latents = self.unpatchify(x, shapes[0], shapes[1])

    if return_past_key_values:
        return latents, past_key_values
    return latents
|
|
||||||
|
|
||||||
@torch.no_grad()
def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
    """Run one batched forward pass and apply classifier-free guidance.

    The batch is split on dim 0 into (cond, uncond, img_cond) when
    ``use_img_cfg`` is true, otherwise into (cond, uncond). The guided result
    is repeated so the output batch size matches the input.

    Returns:
        ``(guided_output, past_key_values)``.
    """
    self.llm.config.use_cache = use_kv_cache
    model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True, offload_model=offload_model)

    if use_img_cfg:
        cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
        guided = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
        copies = 3
    else:
        cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
        guided = uncond + cfg_scale * (cond - uncond)
        copies = 2

    return torch.cat([guided] * copies, dim=0), past_key_values
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
    """Apply classifier-free guidance with one forward pass per guidance branch.

    ``x`` and ``timestep`` are split on dim 0 into per-branch chunks; the
    remaining per-branch arguments are indexed by branch.

    Returns:
        ``(guided_output, list_of_past_key_values)`` for 2 or 3 branches, with
        the guided result repeated once per branch. For any other branch
        count, only the first branch's output is returned.
    """
    self.llm.config.use_cache = use_kv_cache
    if past_key_values is None:
        past_key_values = [None] * len(attention_mask)

    x = torch.split(x, len(x) // len(attention_mask), dim=0)
    timestep = timestep.to(x[0].dtype)
    timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)

    branch_outputs, branch_caches = [], []
    for i, ids in enumerate(input_ids):
        branch_out, branch_cache = self.forward(x[i], timestep[i], ids, input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
        branch_outputs.append(branch_out)
        branch_caches.append(branch_cache)

    branches = len(branch_outputs)
    if branches == 3:
        cond, uncond, img_cond = branch_outputs
        guided = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
    elif branches == 2:
        cond, uncond = branch_outputs
        guided = uncond + cfg_scale * (cond - uncond)
    else:
        # NOTE(review): this path returns a bare output rather than an
        # (output, caches) pair like the other paths; confirm callers expect it.
        return branch_outputs[0]

    return torch.cat([guided] * branches, dim=0), branch_caches
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class OmniGenTransformer(OmniGenOriginalModel):
|
|
||||||
def __init__(self):
    """Build the model from a hard-coded Phi3 configuration."""
    # RoPE rescale factors consumed by the "su" rope_scaling entry below.
    long_factor = [
        1.0299999713897705, 1.0499999523162842, 1.0499999523162842, 1.0799999237060547,
        1.2299998998641968, 1.2299998998641968, 1.2999999523162842, 1.4499999284744263,
        1.5999999046325684, 1.6499998569488525, 1.8999998569488525, 2.859999895095825,
        3.68999981880188, 5.419999599456787, 5.489999771118164, 5.489999771118164,
        9.09000015258789, 11.579999923706055, 15.65999984741211, 15.769999504089355,
        15.789999961853027, 18.360000610351562, 21.989999771118164, 23.079999923706055,
        30.009998321533203, 32.35000228881836, 32.590003967285156, 35.56000518798828,
        39.95000457763672, 53.840003967285156, 56.20000457763672, 57.95000457763672,
        59.29000473022461, 59.77000427246094, 59.920005798339844, 61.190006256103516,
        61.96000671386719, 62.50000762939453, 63.3700065612793, 63.48000717163086,
        63.48000717163086, 63.66000747680664, 63.850006103515625, 64.08000946044922,
        64.760009765625, 64.80001068115234, 64.81001281738281, 64.81001281738281,
    ]
    short_factor = [
        1.05, 1.05, 1.05, 1.1, 1.1, 1.1,
        1.2500000000000002, 1.2500000000000002, 1.4000000000000004, 1.4500000000000004,
        1.5500000000000005, 1.8500000000000008, 1.9000000000000008,
        2.000000000000001, 2.000000000000001, 2.000000000000001, 2.000000000000001,
        2.000000000000001, 2.000000000000001, 2.000000000000001, 2.000000000000001,
        2.000000000000001, 2.000000000000001, 2.000000000000001, 2.000000000000001,
        2.000000000000001, 2.000000000000001, 2.000000000000001, 2.000000000000001,
        2.000000000000001, 2.000000000000001, 2.000000000000001,
        2.1000000000000005, 2.1000000000000005, 2.2,
        2.3499999999999996, 2.3499999999999996, 2.3499999999999996, 2.3499999999999996,
        2.3999999999999995, 2.3999999999999995, 2.6499999999999986, 2.6999999999999984,
        2.8999999999999977, 2.9499999999999975,
        3.049999999999997, 3.049999999999997, 3.049999999999997,
    ]
    settings = {
        "_name_or_path": "Phi-3-vision-128k-instruct",
        "architectures": ["Phi3ForCausalLM"],
        "attention_dropout": 0.0,
        "bos_token_id": 1,
        "eos_token_id": 2,
        "hidden_act": "silu",
        "hidden_size": 3072,
        "initializer_range": 0.02,
        "intermediate_size": 8192,
        "max_position_embeddings": 131072,
        "model_type": "phi3",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 32,
        "original_max_position_embeddings": 4096,
        "rms_norm_eps": 1e-05,
        "rope_scaling": {
            "long_factor": long_factor,
            "short_factor": short_factor,
            "type": "su",
        },
        "rope_theta": 10000.0,
        "sliding_window": 131072,
        "tie_word_embeddings": False,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.38.1",
        "use_cache": True,
        "vocab_size": 32064,
        "_attn_implementation": "sdpa",
    }
    super().__init__(Phi3Config(**settings))
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
    """Embed the inputs, run the LLM backbone and decode the image tokens.

    The sequence fed to the backbone is the optional condition embeddings
    (with input-image spans from ``input_image_sizes`` filled in), then the
    time token, then the image tokens.

    Returns:
        The unpatchified latents (a list when ``x`` was a list); paired with
        the backbone's ``past_key_values`` when ``return_past_key_values``
        is true.
    """
    input_is_list = isinstance(x, list)
    x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
    time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)

    if input_img_latents is not None:
        input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)

    sequence = [time_token, x]
    if input_ids is not None:
        condition_embeds = self.llm.embed_tokens(input_ids).clone()
        image_index = 0
        for b_inx in input_image_sizes.keys():
            for start_inx, end_inx in input_image_sizes[b_inx]:
                # Overwrite the placeholder span with the next embedded image.
                condition_embeds[b_inx, start_inx: end_inx] = input_latents[image_index]
                image_index += 1
        if input_img_latents is not None:
            assert image_index == len(input_latents)
        sequence = [condition_embeds] + sequence
    input_emb = torch.cat(sequence, dim=1)

    output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
    states, past_key_values = output.last_hidden_state, output.past_key_values

    if input_is_list:
        image_embedding = states[:, -max(num_tokens):]
    else:
        image_embedding = states[:, -num_tokens:]
    time_emb = self.t_embedder(timestep, dtype=x.dtype)
    x = self.final_layer(image_embedding, time_emb)

    if input_is_list:
        latents = []
        for i, (count, (height, width)) in enumerate(zip(num_tokens[:x.size(0)], shapes)):
            # Keep only this sample's real tokens before folding back.
            latents.append(self.unpatchify(x[i:i+1, :count], height, width))
    else:
        latents = self.unpatchify(x, shapes[0], shapes[1])

    return (latents, past_key_values) if return_past_key_values else latents
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
    """Apply classifier-free guidance, running each guidance branch separately.

    ``x`` and ``timestep`` are chunked on dim 0, one chunk per branch; the
    other per-branch arguments are indexed by branch number.

    Returns:
        ``(guided_output, caches)`` when there are 2 or 3 branches, where the
        guided tensor is repeated once per branch. With any other number of
        branches only the first branch's output is returned.
    """
    self.llm.config.use_cache = use_kv_cache
    if past_key_values is None:
        past_key_values = [None] * len(attention_mask)

    x = torch.split(x, len(x) // len(attention_mask), dim=0)
    timestep = timestep.to(x[0].dtype)
    timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)

    results, caches = [], []
    for idx in range(len(input_ids)):
        result, cache = self.forward(
            x[idx], timestep[idx], input_ids[idx], input_img_latents[idx],
            input_image_sizes[idx], attention_mask[idx], position_ids[idx],
            past_key_values=past_key_values[idx],
            return_past_key_values=True,
            offload_model=offload_model,
        )
        results.append(result)
        caches.append(cache)

    if len(results) == 3:
        cond, uncond, img_cond = results
        mixed = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
        return torch.cat([mixed, mixed, mixed], dim=0), caches
    if len(results) == 2:
        cond, uncond = results
        mixed = uncond + cfg_scale * (cond - uncond)
        return torch.cat([mixed, mixed], dim=0), caches
    # NOTE(review): no cache list is returned on this path, unlike the two
    # guided paths above; confirm callers handle the different return shape.
    return results[0]
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
def state_dict_converter():
    """Return a new state-dict converter instance for this model."""
    converter = OmniGenTransformerStateDictConverter()
    return converter
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class OmniGenTransformerStateDictConverter:
    """Identity converter: both entry points return the state dict unchanged."""

    def __init__(self):
        """No state is kept."""

    def from_diffusers(self, state_dict):
        """Return ``state_dict`` as-is."""
        converted = state_dict
        return converted

    def from_civitai(self, state_dict):
        """Return ``state_dict`` as-is."""
        converted = state_dict
        return converted
|
|
||||||
@@ -1,551 +0,0 @@
|
|||||||
import torch
|
|
||||||
from einops import rearrange
|
|
||||||
from .svd_unet import TemporalTimesteps
|
|
||||||
from .tiler import TileWorker
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension.

    Args:
        dim: Size of the last dimension (length of the learnable scale).
        eps: Added to the mean square before the reciprocal square root.
        elementwise_affine: When true, multiply by a learnable per-feature
            weight initialized to ones; otherwise ``weight`` is None.
    """

    def __init__(self, dim, eps, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones((dim,))) if elementwise_affine else None

    def forward(self, hidden_states):
        """Normalize ``hidden_states`` and cast back to its input dtype."""
        input_dtype = hidden_states.dtype
        # Statistics are computed in float32 regardless of the input dtype.
        mean_square = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
        normed = (hidden_states * torch.rsqrt(mean_square + self.eps)).to(input_dtype)
        if self.weight is None:
            return normed
        return normed * self.weight
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class PatchEmbed(torch.nn.Module):
    """Patch embedding with a learnable, centre-cropped positional table.

    Args:
        patch_size: Side length of the square patches (conv kernel and stride).
        in_channels: Channels of the incoming latent.
        embed_dim: Width of each output token.
        pos_embed_max_size: Side length of the stored positional grid.
    """

    def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
        super().__init__()
        self.pos_embed_max_size = pos_embed_max_size
        self.patch_size = patch_size

        self.proj = torch.nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=(patch_size, patch_size),
            stride=patch_size,
        )
        grid_shape = (1, self.pos_embed_max_size, self.pos_embed_max_size, embed_dim)
        self.pos_embed = torch.nn.Parameter(torch.zeros(*grid_shape))

    def cropped_pos_embed(self, height, width):
        """Return the centre window of the table for a latent of the given size."""
        rows = height // self.patch_size
        cols = width // self.patch_size
        top = (self.pos_embed_max_size - rows) // 2
        left = (self.pos_embed_max_size - cols) // 2
        window = self.pos_embed[:, top : top + rows, left : left + cols, :]
        # Merge the two spatial dims into one token axis.
        return window.flatten(1, 2)

    def forward(self, latent):
        """Embed ``latent`` into tokens and add the matching positional window."""
        height, width = latent.shape[-2:]
        tokens = self.proj(latent).flatten(2).transpose(1, 2)
        return tokens + self.cropped_pos_embed(height, width)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TimestepEmbeddings(torch.nn.Module):
    """Project timesteps with ``TemporalTimesteps`` and refine them with an MLP.

    Args:
        dim_in: Channel count produced by the timestep projection.
        dim_out: Width of the returned embedding.
        computation_device: Forwarded unchanged to ``TemporalTimesteps``.
    """

    def __init__(self, dim_in, dim_out, computation_device=None):
        super().__init__()
        self.time_proj = TemporalTimesteps(
            num_channels=dim_in,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
            computation_device=computation_device,
        )
        mlp_layers = [
            torch.nn.Linear(dim_in, dim_out),
            torch.nn.SiLU(),
            torch.nn.Linear(dim_out, dim_out),
        ]
        self.timestep_embedder = torch.nn.Sequential(*mlp_layers)

    def forward(self, timestep, dtype):
        """Return the embedding of ``timestep``, projected in ``dtype``."""
        projected = self.time_proj(timestep).to(dtype)
        return self.timestep_embedder(projected)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AdaLayerNorm(torch.nn.Module):
    """LayerNorm whose scale/shift (and gates) come from a conditioning vector.

    The linear layer emits ``dim * k`` values, chunked on the last axis:
    ``k = 2`` in single mode, ``k = 9`` in dual mode, ``k = 6`` otherwise.
    """

    def __init__(self, dim, single=False, dual=False):
        super().__init__()
        self.single = single
        self.dual = dual
        # Flag-indexed lookup: (6, 2)[single], then overridden by 9 if dual.
        base_chunks = (6, 2)[single]
        num_chunks = (base_chunks, 9)[dual]
        self.linear = torch.nn.Linear(dim, dim * num_chunks)
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        """Modulate ``norm(x)`` with parameters derived from ``emb``.

        Returns:
            * single: the modulated tensor.
            * dual: ``(x, gate_msa, shift_mlp, scale_mlp, gate_mlp, x2, gate_msa2)``.
            * default: ``(x, gate_msa, shift_mlp, scale_mlp, gate_mlp)``.
        """
        # unsqueeze(1) lets the parameters broadcast over x's dim 1.
        params = self.linear(torch.nn.functional.silu(emb)).unsqueeze(1)

        if self.single:
            scale, shift = params.chunk(2, dim=2)
            return self.norm(x) * (1 + scale) + shift

        if self.dual:
            (shift_msa, scale_msa, gate_msa,
             shift_mlp, scale_mlp, gate_mlp,
             shift_msa2, scale_msa2, gate_msa2) = params.chunk(9, dim=2)
            normed = self.norm(x)
            first = normed * (1 + scale_msa) + shift_msa
            second = normed * (1 + scale_msa2) + shift_msa2
            return first, gate_msa, shift_mlp, scale_mlp, gate_mlp, second, gate_msa2

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params.chunk(6, dim=2)
        modulated = self.norm(x) * (1 + scale_msa) + shift_msa
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class JointAttention(torch.nn.Module):
    """Attention over two token streams concatenated along the sequence axis.

    Each stream has its own fused QKV projection and (optionally) its own
    per-head RMSNorm on queries and keys. After attention the result is split
    back into the two streams and projected separately.

    Args:
        dim_a: Feature width of stream A.
        dim_b: Feature width of stream B.
        num_heads: Number of attention heads.
        head_dim: Width of each head.
        only_out_a: When true, no output projection is created for stream B
            and ``forward`` returns only stream A.
        use_rms_norm: Enable RMSNorm on the queries and keys of both streams.
    """

    def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False, use_rms_norm=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.only_out_a = only_out_a

        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
        self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)

        self.a_to_out = torch.nn.Linear(dim_a, dim_a)
        if not only_out_a:
            self.b_to_out = torch.nn.Linear(dim_b, dim_b)

        self.norm_q_a = RMSNorm(head_dim, eps=1e-6) if use_rms_norm else None
        self.norm_k_a = RMSNorm(head_dim, eps=1e-6) if use_rms_norm else None
        self.norm_q_b = RMSNorm(head_dim, eps=1e-6) if use_rms_norm else None
        self.norm_k_b = RMSNorm(head_dim, eps=1e-6) if use_rms_norm else None

    def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
        """Project to fused QKV, split into heads, and normalize q/k if given.

        Returns:
            ``(q, k, v)``, each shaped (batch, num_heads, seq, head_dim).
        """
        batch_size = hidden_states.shape[0]
        fused = to_qkv(hidden_states)
        fused = fused.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = fused.chunk(3, dim=1)
        q = q if norm_q is None else norm_q(q)
        k = k if norm_k is None else norm_k(k)
        return q, k, v

    def forward(self, hidden_states_a, hidden_states_b):
        """Attend jointly over both streams.

        Returns:
            Stream A's output when ``only_out_a`` is set, otherwise the pair
            ``(output_a, output_b)``.
        """
        batch_size = hidden_states_a.shape[0]
        len_a = hidden_states_a.shape[1]

        qkv_a = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)
        qkv_b = self.process_qkv(hidden_states_b, self.b_to_qkv, self.norm_q_b, self.norm_k_b)
        # Stream A tokens come first on the sequence axis, then stream B.
        q, k, v = (torch.concat(pair, dim=2) for pair in zip(qkv_a, qkv_b))

        attended = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        attended = attended.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        attended = attended.to(q.dtype)

        out_a = self.a_to_out(attended[:, :len_a])
        if self.only_out_a:
            return out_a
        out_b = self.b_to_out(attended[:, len_a:])
        return out_a, out_b
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SingleAttention(torch.nn.Module):
    """Multi-head self-attention over a single token stream.

    Args:
        dim_a: Feature width of the stream.
        num_heads: Number of attention heads.
        head_dim: Width of each head.
        use_rms_norm: Enable per-head RMSNorm on queries and keys.
    """

    def __init__(self, dim_a, num_heads, head_dim, use_rms_norm=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
        self.a_to_out = torch.nn.Linear(dim_a, dim_a)

        self.norm_q_a = RMSNorm(head_dim, eps=1e-6) if use_rms_norm else None
        self.norm_k_a = RMSNorm(head_dim, eps=1e-6) if use_rms_norm else None

    def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
        """Project to fused QKV, split into heads, and normalize q/k if given.

        Returns:
            ``(q, k, v)``, each shaped (batch, num_heads, seq, head_dim).
        """
        batch_size = hidden_states.shape[0]
        fused = to_qkv(hidden_states)
        fused = fused.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = fused.chunk(3, dim=1)
        q = q if norm_q is None else norm_q(q)
        k = k if norm_k is None else norm_k(k)
        return q, k, v

    def forward(self, hidden_states_a):
        """Return the self-attention output, same shape as the input."""
        batch_size = hidden_states_a.shape[0]
        q, k, v = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)

        attended = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        attended = attended.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        attended = attended.to(q.dtype)
        return self.a_to_out(attended)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DualTransformerBlock(torch.nn.Module):
    """Two-stream transformer block with an extra self-attention on stream A.

    Stream A and stream B first go through a joint attention. Stream A then
    receives a second, A-only attention before its feed-forward; stream B
    goes straight to its feed-forward.

    Args:
        dim: Feature width of both streams.
        num_attention_heads: Number of heads; head width is
            ``dim // num_attention_heads``.
        use_rms_norm: Enable RMSNorm on queries/keys inside the attentions.
    """

    def __init__(self, dim, num_attention_heads, use_rms_norm=False):
        super().__init__()
        head_dim = dim // num_attention_heads

        self.norm1_a = AdaLayerNorm(dim, dual=True)
        self.norm1_b = AdaLayerNorm(dim)

        self.attn = JointAttention(dim, dim, num_attention_heads, head_dim, use_rms_norm=use_rms_norm)
        # The second attention is applied to stream A alone in forward(), so
        # it must be a single-stream attention. It was previously built as a
        # JointAttention, whose forward requires two inputs and returns a
        # tuple, making forward() raise TypeError. This matches the dual
        # branch of JointTransformerBlock.
        # NOTE(review): this drops the `attn2.b_to_*` / `attn2.norm_*_b`
        # parameters the old construction registered; confirm no checkpoint
        # relies on those keys.
        self.attn2 = SingleAttention(dim, num_attention_heads, head_dim, use_rms_norm=use_rms_norm)

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_a = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

        self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_b = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

    def forward(self, hidden_states_a, hidden_states_b, temb):
        """Update both streams, conditioned on ``temb``.

        Args:
            hidden_states_a: Stream A tokens.
            hidden_states_b: Stream B tokens.
            temb: Conditioning embedding fed to both adaptive layer norms.

        Returns:
            ``(hidden_states_a, hidden_states_b)`` after the block.
        """
        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
        norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)

        # Attention
        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)

        # Part A: gated joint attention, gated A-only attention, gated MLP.
        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
        hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
        norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)

        # Part B: gated joint attention, gated MLP.
        hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
        norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
        hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)

        return hidden_states_a, hidden_states_b
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class JointTransformerBlock(torch.nn.Module):
    """Transformer block that updates two token streams with one joint attention.

    Each stream (``a`` and ``b``) has its own adaptive norm, plain LayerNorm
    and feed-forward network. When ``dual`` is true, stream ``a`` additionally
    passes through a second attention module (``attn2``).
    """

    def __init__(self, dim, num_attention_heads, use_rms_norm=False, dual=False):
        super().__init__()
        head_dim = dim // num_attention_heads

        self.norm1_a = AdaLayerNorm(dim, dual=dual)
        self.norm1_b = AdaLayerNorm(dim)

        self.attn = JointAttention(dim, dim, num_attention_heads, head_dim, use_rms_norm=use_rms_norm)
        if dual:
            # Only created for dual blocks; forward() checks ``norm1_a.dual`` before using it.
            self.attn2 = SingleAttention(dim, num_attention_heads, head_dim, use_rms_norm=use_rms_norm)

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_a = self._build_feed_forward(dim)

        self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_b = self._build_feed_forward(dim)

    @staticmethod
    def _build_feed_forward(dim):
        """Return the Linear -> tanh-GELU -> Linear MLP with a 4x hidden width."""
        hidden_dim = dim * 4
        return torch.nn.Sequential(
            torch.nn.Linear(dim, hidden_dim),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(hidden_dim, dim),
        )

    def forward(self, hidden_states_a, hidden_states_b, temb):
        """Return the updated ``(hidden_states_a, hidden_states_b)`` pair.

        ``temb`` is passed to both adaptive norms as ``emb``.
        """
        is_dual = self.norm1_a.dual

        # The adaptive norm of stream a yields two extra values in dual mode.
        modulation_a = self.norm1_a(hidden_states_a, emb=temb)
        if is_dual:
            (
                attn_in_a, gate_attn_a, shift_ff_a, scale_ff_a, gate_ff_a,
                attn2_in_a, gate_attn2_a,
            ) = modulation_a
        else:
            attn_in_a, gate_attn_a, shift_ff_a, scale_ff_a, gate_ff_a = modulation_a
        attn_in_b, gate_attn_b, shift_ff_b, scale_ff_b, gate_ff_b = self.norm1_b(hidden_states_b, emb=temb)

        # One attention call produces the update for both streams.
        joint_out_a, joint_out_b = self.attn(attn_in_a, attn_in_b)

        # Stream a.
        hidden_states_a = hidden_states_a + gate_attn_a * joint_out_a
        if is_dual:
            hidden_states_a = hidden_states_a + gate_attn2_a * self.attn2(attn2_in_a)
        ff_in_a = self.norm2_a(hidden_states_a) * (1 + scale_ff_a) + shift_ff_a
        hidden_states_a = hidden_states_a + gate_ff_a * self.ff_a(ff_in_a)

        # Stream b.
        hidden_states_b = hidden_states_b + gate_attn_b * joint_out_b
        ff_in_b = self.norm2_b(hidden_states_b) * (1 + scale_ff_b) + shift_ff_b
        hidden_states_b = hidden_states_b + gate_ff_b * self.ff_b(ff_in_b)

        return hidden_states_a, hidden_states_b
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class JointTransformerFinalBlock(torch.nn.Module):
    """Last joint block: only stream ``a`` is updated.

    Stream ``b`` is normalised and fed into the joint attention, but the
    attention is built with ``only_out_a=True`` and ``hidden_states_b`` is
    returned unchanged.
    """

    def __init__(self, dim, num_attention_heads, use_rms_norm=False):
        super().__init__()
        head_dim = dim // num_attention_heads

        self.norm1_a = AdaLayerNorm(dim)
        self.norm1_b = AdaLayerNorm(dim, single=True)

        self.attn = JointAttention(
            dim, dim, num_attention_heads, head_dim,
            only_out_a=True, use_rms_norm=use_rms_norm,
        )

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        hidden_dim = dim * 4
        self.ff_a = torch.nn.Sequential(
            torch.nn.Linear(dim, hidden_dim),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(hidden_dim, dim),
        )

    def forward(self, hidden_states_a, hidden_states_b, temb):
        """Return ``(updated hidden_states_a, untouched hidden_states_b)``."""
        attn_in_a, gate_attn_a, shift_ff_a, scale_ff_a, gate_ff_a = self.norm1_a(hidden_states_a, emb=temb)
        # In this block the norm of stream b yields a single value.
        attn_in_b = self.norm1_b(hidden_states_b, emb=temb)

        # The attention returns only the update for stream a.
        joint_out_a = self.attn(attn_in_a, attn_in_b)

        # Residual attention update followed by the gated feed-forward update.
        hidden_states_a = hidden_states_a + gate_attn_a * joint_out_a
        ff_in_a = self.norm2_a(hidden_states_a) * (1 + scale_ff_a) + shift_ff_a
        hidden_states_a = hidden_states_a + gate_ff_a * self.ff_a(ff_in_a)

        return hidden_states_a, hidden_states_b
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SD3DiT(torch.nn.Module):
    """Diffusion transformer built from a stack of joint two-stream blocks.

    The stack consists of ``num_dual_blocks`` dual blocks, then plain joint
    blocks, and ends with one ``JointTransformerFinalBlock``, for a total of
    ``num_layers`` blocks.
    """

    def __init__(self, embed_dim=1536, num_layers=24, use_rms_norm=False, num_dual_blocks=0, pos_embed_max_size=192):
        super().__init__()
        num_heads = embed_dim // 64

        self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=embed_dim, pos_embed_max_size=pos_embed_max_size)
        self.time_embedder = TimestepEmbeddings(256, embed_dim)
        self.pooled_text_embedder = torch.nn.Sequential(
            torch.nn.Linear(2048, embed_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(embed_dim, embed_dim),
        )
        self.context_embedder = torch.nn.Linear(4096, embed_dim)

        # Blocks are created in stack order: dual, plain, final.
        dual_blocks = [
            JointTransformerBlock(embed_dim, num_heads, use_rms_norm=use_rms_norm, dual=True)
            for _ in range(num_dual_blocks)
        ]
        plain_blocks = [
            JointTransformerBlock(embed_dim, num_heads, use_rms_norm=use_rms_norm)
            for _ in range(num_layers - 1 - num_dual_blocks)
        ]
        final_block = JointTransformerFinalBlock(embed_dim, num_heads, use_rms_norm=use_rms_norm)
        self.blocks = torch.nn.ModuleList(dual_blocks + plain_blocks + [final_block])

        self.norm_out = AdaLayerNorm(embed_dim, single=True)
        self.proj_out = torch.nn.Linear(embed_dim, 64)

    def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
        """Run the whole model tile by tile through ``TileWorker``."""
        # The positional embedding is global, so tiling cannot be applied layer by
        # layer; each tile goes through the complete (non-tiled) forward pass.
        def run_on_tile(tile):
            return self.forward(tile, timestep, prompt_emb, pooled_prompt_emb)

        return TileWorker().tiled_forward(
            run_on_tile,
            hidden_states,
            tile_size,
            tile_stride,
            tile_device=hidden_states.device,
            tile_dtype=hidden_states.dtype
        )

    def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
        """Compute the model output for ``hidden_states``.

        With ``tiled=True`` the call is delegated to ``tiled_forward``.
        Gradient checkpointing is applied per block, and only while the
        module is in training mode.
        """
        if tiled:
            return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)

        # Shared conditioning vector: timestep embedding plus pooled prompt embedding.
        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
        prompt_emb = self.context_embedder(prompt_emb)

        # Remember the spatial size before patchifying; it is needed to unpatchify.
        height, width = hidden_states.shape[-2:]
        hidden_states = self.pos_embedder(hidden_states)

        checkpointing = self.training and use_gradient_checkpointing
        for block in self.blocks:
            if checkpointing:
                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states, prompt_emb, conditioning,
                    use_reentrant=False,
                )
            else:
                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)

        hidden_states = self.norm_out(hidden_states, conditioning)
        hidden_states = self.proj_out(hidden_states)
        # Fold the 2x2 patches back into a (B, C, H, W) map.
        hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
        return hidden_states

    @staticmethod
    def state_dict_converter():
        """Return the converter used to load external checkpoints into this model."""
        return SD3DiTStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SD3DiTStateDictConverter:
    """Convert external SD3 DiT checkpoints into this module's parameter names.

    Both converters return ``(converted_state_dict, architecture_kwargs)``,
    where the kwargs are inferred from the converted tensors by
    ``infer_architecture``.
    """

    def __init__(self):
        pass

    @staticmethod
    def _reshape_pos_embed(param):
        """Unflatten a ``(1, S*S, C)`` positional table into ``(1, S, S, C)``.

        The grid side ``S`` is derived from the tensor itself, so tables of any
        square size are handled (not only 192x192).
        """
        # +0.4 guards against the float square root landing just below the integer.
        side = int(param.shape[1] ** 0.5 + 0.4)
        return param.reshape((1, side, side, param.shape[-1]))

    def infer_architecture(self, state_dict):
        """Derive the model constructor kwargs from a converted state dict.

        Raises ``KeyError`` if ``blocks.0.ff_a.0.weight`` or
        ``pos_embedder.pos_embed`` is missing.
        """
        embed_dim = state_dict["blocks.0.ff_a.0.weight"].shape[1]

        # Highest block index present (searched downwards from 99) + 1.
        num_layers = 100
        while num_layers > 0 and f"blocks.{num_layers-1}.ff_a.0.bias" not in state_dict:
            num_layers -= 1

        use_rms_norm = "blocks.0.attn.norm_q_a.weight" in state_dict

        # Dual blocks are the leading blocks that carry an ``attn2`` module.
        num_dual_blocks = 0
        while f"blocks.{num_dual_blocks}.attn2.a_to_out.bias" in state_dict:
            num_dual_blocks += 1

        pos_embed_max_size = state_dict["pos_embedder.pos_embed"].shape[1]
        return {
            "embed_dim": embed_dim,
            "num_layers": num_layers,
            "use_rms_norm": use_rms_norm,
            "num_dual_blocks": num_dual_blocks,
            "pos_embed_max_size": pos_embed_max_size
        }

    def from_diffusers(self, state_dict):
        """Convert a Diffusers-style state dict; separate q/k/v are fused into qkv."""
        rename_dict = {
            "context_embedder": "context_embedder",
            "pos_embed.pos_embed": "pos_embedder.pos_embed",
            "pos_embed.proj": "pos_embedder.proj",
            "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
            "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
            "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
            "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
            "norm_out.linear": "norm_out.linear",
            "proj_out": "proj_out",

            "norm1.linear": "norm1_a.linear",
            "norm1_context.linear": "norm1_b.linear",
            "attn.to_q": "attn.a_to_q",
            "attn.to_k": "attn.a_to_k",
            "attn.to_v": "attn.a_to_v",
            "attn.to_out.0": "attn.a_to_out",
            "attn.add_q_proj": "attn.b_to_q",
            "attn.add_k_proj": "attn.b_to_k",
            "attn.add_v_proj": "attn.b_to_v",
            "attn.to_add_out": "attn.b_to_out",
            "ff.net.0.proj": "ff_a.0",
            "ff.net.2": "ff_a.2",
            "ff_context.net.0.proj": "ff_b.0",
            "ff_context.net.2": "ff_b.2",

            "attn.norm_q": "attn.norm_q_a",
            "attn.norm_k": "attn.norm_k_a",
            "attn.norm_added_q": "attn.norm_q_b",
            "attn.norm_added_k": "attn.norm_k_b",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                # Exact-name match (parameters without a .weight/.bias suffix).
                if name == "pos_embed.pos_embed":
                    param = self._reshape_pos_embed(param)
                state_dict_[rename_dict[name]] = param
            elif name.endswith(".weight") or name.endswith(".bias"):
                suffix = ".weight" if name.endswith(".weight") else ".bias"
                prefix = name[:-len(suffix)]
                if prefix in rename_dict:
                    state_dict_[rename_dict[prefix] + suffix] = param
                elif prefix.startswith("transformer_blocks."):
                    # "transformer_blocks.<i>.<middle>" -> "blocks.<i>.<renamed middle>"
                    names = prefix.split(".")
                    names[0] = "blocks"
                    middle = ".".join(names[2:])
                    if middle in rename_dict:
                        name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
                        state_dict_[name_] = param

        # Fuse every q/k/v triple into a single qkv tensor (stacked along dim 0).
        q_keys = [name for name in state_dict_ if ".a_to_q." in name or ".b_to_q." in name]
        for q_key in q_keys:
            k_key = q_key.replace("to_q", "to_k")
            v_key = q_key.replace("to_q", "to_v")
            fused = torch.concat([state_dict_[q_key], state_dict_[k_key], state_dict_[v_key]], dim=0)
            for part_key in (q_key, k_key, v_key):
                state_dict_.pop(part_key)
            state_dict_[q_key.replace("to_q", "to_qkv")] = fused
        return state_dict_, self.infer_architecture(state_dict_)

    def from_civitai(self, state_dict):
        """Convert a single-file checkpoint keyed under ``model.diffusion_model.``."""
        rename_dict = {
            "model.diffusion_model.context_embedder.bias": "context_embedder.bias",
            "model.diffusion_model.context_embedder.weight": "context_embedder.weight",
            "model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
            "model.diffusion_model.final_layer.linear.weight": "proj_out.weight",

            "model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
            "model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
            "model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
            "model.diffusion_model.t_embedder.mlp.2.bias": "time_embedder.timestep_embedder.2.bias",
            "model.diffusion_model.t_embedder.mlp.2.weight": "time_embedder.timestep_embedder.2.weight",
            "model.diffusion_model.x_embedder.proj.bias": "pos_embedder.proj.bias",
            "model.diffusion_model.x_embedder.proj.weight": "pos_embedder.proj.weight",
            "model.diffusion_model.y_embedder.mlp.0.bias": "pooled_text_embedder.0.bias",
            "model.diffusion_model.y_embedder.mlp.0.weight": "pooled_text_embedder.0.weight",
            "model.diffusion_model.y_embedder.mlp.2.bias": "pooled_text_embedder.2.bias",
            "model.diffusion_model.y_embedder.mlp.2.weight": "pooled_text_embedder.2.weight",

            "model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
            "model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
        }
        # Per-block names. The source checkpoint already stores fused qkv tensors,
        # so they map straight onto the fused ``*_to_qkv`` parameters.
        for i in range(40):
            src = f"model.diffusion_model.joint_blocks.{i}"
            dst = f"blocks.{i}"
            rename_dict.update({
                f"{src}.context_block.adaLN_modulation.1.bias": f"{dst}.norm1_b.linear.bias",
                f"{src}.context_block.adaLN_modulation.1.weight": f"{dst}.norm1_b.linear.weight",
                f"{src}.context_block.attn.proj.bias": f"{dst}.attn.b_to_out.bias",
                f"{src}.context_block.attn.proj.weight": f"{dst}.attn.b_to_out.weight",
                f"{src}.context_block.attn.qkv.bias": f"{dst}.attn.b_to_qkv.bias",
                f"{src}.context_block.attn.qkv.weight": f"{dst}.attn.b_to_qkv.weight",
                f"{src}.context_block.mlp.fc1.bias": f"{dst}.ff_b.0.bias",
                f"{src}.context_block.mlp.fc1.weight": f"{dst}.ff_b.0.weight",
                f"{src}.context_block.mlp.fc2.bias": f"{dst}.ff_b.2.bias",
                f"{src}.context_block.mlp.fc2.weight": f"{dst}.ff_b.2.weight",
                f"{src}.x_block.adaLN_modulation.1.bias": f"{dst}.norm1_a.linear.bias",
                f"{src}.x_block.adaLN_modulation.1.weight": f"{dst}.norm1_a.linear.weight",
                f"{src}.x_block.attn.proj.bias": f"{dst}.attn.a_to_out.bias",
                f"{src}.x_block.attn.proj.weight": f"{dst}.attn.a_to_out.weight",
                f"{src}.x_block.attn.qkv.bias": f"{dst}.attn.a_to_qkv.bias",
                f"{src}.x_block.attn.qkv.weight": f"{dst}.attn.a_to_qkv.weight",
                f"{src}.x_block.mlp.fc1.bias": f"{dst}.ff_a.0.bias",
                f"{src}.x_block.mlp.fc1.weight": f"{dst}.ff_a.0.weight",
                f"{src}.x_block.mlp.fc2.bias": f"{dst}.ff_a.2.bias",
                f"{src}.x_block.mlp.fc2.weight": f"{dst}.ff_a.2.weight",
                f"{src}.x_block.attn.ln_q.weight": f"{dst}.attn.norm_q_a.weight",
                f"{src}.x_block.attn.ln_k.weight": f"{dst}.attn.norm_k_a.weight",
                f"{src}.context_block.attn.ln_q.weight": f"{dst}.attn.norm_q_b.weight",
                f"{src}.context_block.attn.ln_k.weight": f"{dst}.attn.norm_k_b.weight",

                f"{src}.x_block.attn2.ln_q.weight": f"{dst}.attn2.norm_q_a.weight",
                f"{src}.x_block.attn2.ln_k.weight": f"{dst}.attn2.norm_k_a.weight",
                f"{src}.x_block.attn2.qkv.weight": f"{dst}.attn2.a_to_qkv.weight",
                f"{src}.x_block.attn2.qkv.bias": f"{dst}.attn2.a_to_qkv.bias",
                f"{src}.x_block.attn2.proj.weight": f"{dst}.attn2.a_to_out.weight",
                f"{src}.x_block.attn2.proj.bias": f"{dst}.attn2.a_to_out.bias",
            })

        state_dict_ = {}
        for name in state_dict:
            if name not in rename_dict:
                continue  # unmapped keys are dropped
            param = state_dict[name]
            if name == "model.diffusion_model.pos_embed":
                param = self._reshape_pos_embed(param)
            state_dict_[rename_dict[name]] = param

        extra_kwargs = self.infer_architecture(state_dict_)
        num_layers = extra_kwargs["num_layers"]
        # For the last block's norm1_b and for norm_out, swap the two halves along
        # dim 0. NOTE(review): presumably the source stores the modulation halves
        # in the opposite order from what this model's AdaLayerNorm expects --
        # confirm against the AdaLayerNorm implementation.
        for name in [
            f"blocks.{num_layers-1}.norm1_b.linear.weight", f"blocks.{num_layers-1}.norm1_b.linear.bias",
            "norm_out.linear.weight", "norm_out.linear.bias",
        ]:
            param = state_dict_[name]
            dim = param.shape[0] // 2
            state_dict_[name] = torch.concat([param[dim:], param[:dim]], dim=0)
        # The swap does not change any shape or key, so the kwargs are still valid.
        return state_dict_, extra_kwargs
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,81 +0,0 @@
|
|||||||
import torch
|
|
||||||
from .sd_vae_decoder import VAEAttentionBlock, SDVAEDecoderStateDictConverter
|
|
||||||
from .sd_unet import ResnetBlock, UpSampler
|
|
||||||
from .tiler import TileWorker
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SD3VAEDecoder(torch.nn.Module):
    """VAE decoder taking 16-channel latents to 3-channel output.

    Differs from the SD 1.x decoder in the scaling factor, the added shift
    factor and the 16 input channels of ``conv_in``.
    """

    def __init__(self):
        super().__init__()
        self.scaling_factor = 1.5305 # Different from SD 1.x
        self.shift_factor = 0.0609 # Different from SD 1.x
        self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x

        # UNetMidBlock2D
        layers = [
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
        ]
        # Four UpDecoderBlock2D stages: (in_channels, out_channels, has_upsampler).
        # Each stage is three ResnetBlocks, optionally followed by an UpSampler.
        up_stages = [
            (512, 512, True),
            (512, 512, True),
            (512, 256, True),
            (256, 128, False),
        ]
        for in_channels, out_channels, has_upsampler in up_stages:
            layers.append(ResnetBlock(in_channels, out_channels, eps=1e-6))
            layers.append(ResnetBlock(out_channels, out_channels, eps=1e-6))
            layers.append(ResnetBlock(out_channels, out_channels, eps=1e-6))
            if has_upsampler:
                layers.append(UpSampler(out_channels))
        self.blocks = torch.nn.ModuleList(layers)

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        """Decode ``sample`` tile by tile through ``TileWorker``."""
        def decode_tile(tile):
            return self.forward(tile)

        return TileWorker().tiled_forward(
            decode_tile,
            sample,
            tile_size,
            tile_stride,
            tile_device=sample.device,
            tile_dtype=sample.dtype
        )

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        """Decode ``sample``; extra keyword arguments are accepted and ignored."""
        # Tiling wraps the whole decoder rather than individual layers.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process: undo the latent scaling and shift
        hidden_states = sample / self.scaling_factor + self.shift_factor
        hidden_states = self.conv_in(hidden_states)

        # 2. blocks: they share the UNet block call signature, so the unused
        # time/text/residual slots are threaded through as None.
        time_emb, text_emb, res_stack = None, None, None
        for block in self.blocks:
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        return self.conv_out(hidden_states)

    @staticmethod
    def state_dict_converter():
        """Return the checkpoint converter (the SD VAE decoder converter is reused)."""
        return SDVAEDecoderStateDictConverter()
|
|
||||||
@@ -1,95 +0,0 @@
|
|||||||
import torch
|
|
||||||
from .sd_unet import ResnetBlock, DownSampler
|
|
||||||
from .sd_vae_encoder import VAEAttentionBlock, SDVAEEncoderStateDictConverter
|
|
||||||
from .tiler import TileWorker
|
|
||||||
from einops import rearrange
|
|
||||||
|
|
||||||
|
|
||||||
class SD3VAEEncoder(torch.nn.Module):
    """VAE encoder taking 3-channel input to 16-channel scaled latents.

    ``conv_out`` produces 32 channels; only the first 16 are kept and then
    shifted and scaled.
    """

    def __init__(self):
        super().__init__()
        self.scaling_factor = 1.5305 # Different from SD 1.x
        self.shift_factor = 0.0609 # Different from SD 1.x
        self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)

        # Four DownEncoderBlock2D stages: (in_channels, out_channels, has_downsampler).
        # Each stage is two ResnetBlocks, optionally followed by a DownSampler.
        down_stages = [
            (128, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 512, False),
        ]
        layers = []
        for in_channels, out_channels, has_downsampler in down_stages:
            layers.append(ResnetBlock(in_channels, out_channels, eps=1e-6))
            layers.append(ResnetBlock(out_channels, out_channels, eps=1e-6))
            if has_downsampler:
                layers.append(DownSampler(out_channels, padding=0, extra_padding=True))
        # UNetMidBlock2D
        layers.append(ResnetBlock(512, 512, eps=1e-6))
        layers.append(VAEAttentionBlock(1, 512, 512, 1, eps=1e-6))
        layers.append(ResnetBlock(512, 512, eps=1e-6))
        self.blocks = torch.nn.ModuleList(layers)

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        """Encode ``sample`` tile by tile through ``TileWorker``."""
        def encode_tile(tile):
            return self.forward(tile)

        return TileWorker().tiled_forward(
            encode_tile,
            sample,
            tile_size,
            tile_stride,
            tile_device=sample.device,
            tile_dtype=sample.dtype
        )

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        """Encode ``sample``; extra keyword arguments are accepted and ignored."""
        # Tiling wraps the whole encoder rather than individual layers.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process
        hidden_states = self.conv_in(sample)

        # 2. blocks: they share the UNet block call signature, so the unused
        # time/text/residual slots are threaded through as None.
        time_emb, text_emb, res_stack = None, None, None
        for block in self.blocks:
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        # Keep the first 16 of the 32 output channels, then shift and scale.
        latents = hidden_states[:, :16]
        return (latents - self.shift_factor) * self.scaling_factor

    def encode_video(self, sample, batch_size=8):
        """Encode a ``B C T H W`` video in chunks of ``batch_size`` frames.

        Frames are folded into the batch axis for each chunk and the encoded
        chunks are concatenated back along the time axis (dim 2).
        """
        B = sample.shape[0]
        num_frames = sample.shape[2]
        encoded_chunks = []
        for start in range(0, num_frames, batch_size):
            stop = min(start + batch_size, num_frames)
            frames = rearrange(sample[:, :, start:stop], "B C T H W -> (B T) C H W")
            encoded = self(frames)
            encoded_chunks.append(rearrange(encoded, "(B T) C H W -> B C T H W", B=B))
        return torch.concat(encoded_chunks, dim=2)

    @staticmethod
    def state_dict_converter():
        """Return the checkpoint converter (the SD VAE encoder converter is reused)."""
        return SDVAEEncoderStateDictConverter()
|
|
||||||
@@ -97,10 +97,9 @@ class SDControlNet(torch.nn.Module):
|
|||||||
self,
|
self,
|
||||||
sample, timestep, encoder_hidden_states, conditioning,
|
sample, timestep, encoder_hidden_states, conditioning,
|
||||||
tiled=False, tile_size=64, tile_stride=32,
|
tiled=False, tile_size=64, tile_stride=32,
|
||||||
**kwargs
|
|
||||||
):
|
):
|
||||||
# 1. time
|
# 1. time
|
||||||
time_emb = self.time_proj(timestep).to(sample.dtype)
|
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||||
time_emb = self.time_embedding(time_emb)
|
time_emb = self.time_embedding(time_emb)
|
||||||
time_emb = time_emb.repeat(sample.shape[0], 1)
|
time_emb = time_emb.repeat(sample.shape[0], 1)
|
||||||
|
|
||||||
@@ -135,8 +134,7 @@ class SDControlNet(torch.nn.Module):
|
|||||||
|
|
||||||
return controlnet_res_stack
|
return controlnet_res_stack
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SDControlNetStateDictConverter()
|
return SDControlNetStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class SDIpAdapter(torch.nn.Module):
|
|||||||
|
|
||||||
def set_less_adapter(self):
|
def set_less_adapter(self):
|
||||||
# IP-Adapter for SD v1.5 doesn't support this feature.
|
# IP-Adapter for SD v1.5 doesn't support this feature.
|
||||||
self.set_full_adapter()
|
self.set_full_adapter(self)
|
||||||
|
|
||||||
def forward(self, hidden_states, scale=1.0):
|
def forward(self, hidden_states, scale=1.0):
|
||||||
hidden_states = self.image_proj(hidden_states)
|
hidden_states = self.image_proj(hidden_states)
|
||||||
@@ -47,8 +47,7 @@ class SDIpAdapter(torch.nn.Module):
|
|||||||
}
|
}
|
||||||
return ip_kv_dict
|
return ip_kv_dict
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SDIpAdapterStateDictConverter()
|
return SDIpAdapterStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
60
diffsynth/models/sd_lora.py
Normal file
60
diffsynth/models/sd_lora.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
import torch
|
||||||
|
from .sd_unet import SDUNetStateDictConverter, SDUNet
|
||||||
|
from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder
|
||||||
|
|
||||||
|
|
||||||
|
class SDLoRA:
    """Merge LoRA weight deltas into an SD UNet or text encoder in place."""

    def __init__(self):
        pass

    def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
        """Turn LoRA up/down pairs into full-size weight deltas.

        Only keys that contain ``.lora_up`` and start with ``lora_prefix`` are
        used; the matching ``.lora_down`` key must exist (``KeyError``
        otherwise). Returns ``{target_name: alpha * up @ down}`` with the
        tensors moved to the CPU.

        The matrix product is computed on ``device``. Previously this argument
        was ignored and CUDA was always used.
        """
        # Key fragments whose underscores were turned into dots by the blanket
        # "_" -> "." replacement below and have to be restored.
        special_keys = {
            "down.blocks": "down_blocks",
            "up.blocks": "up_blocks",
            "mid.block": "mid_block",
            "proj.in": "proj_in",
            "proj.out": "proj_out",
            "transformer.blocks": "transformer_blocks",
            "to.q": "to_q",
            "to.k": "to_k",
            "to.v": "to_v",
            "to.out": "to_out",
        }
        state_dict_ = {}
        for key in state_dict:
            if ".lora_up" not in key:
                continue
            if not key.startswith(lora_prefix):
                continue
            weight_up = state_dict[key].to(device=device, dtype=torch.float16)
            weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device=device, dtype=torch.float16)
            if len(weight_up.shape) == 4:
                # 4-D pair: drop the two trailing size-1 dims, multiply in
                # float32, then restore the 4-D layout.
                weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
                weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                # NOTE(review): this product stays in float16; half-precision
                # matmul on CPU may be unsupported on older torch builds.
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
            for special_key, restored in special_keys.items():
                target_name = target_name.replace(special_key, restored)
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_

    def add_lora_to_unet(self, unet: "SDUNet", state_dict_lora, alpha=1.0, device="cuda"):
        """Add the ``lora_unet_`` deltas to ``unet``'s weights and reload them."""
        state_dict_unet = unet.state_dict()
        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
        state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
        if len(state_dict_lora) > 0:
            for name in state_dict_lora:
                state_dict_unet[name] += state_dict_lora[name].to(device=device)
            unet.load_state_dict(state_dict_unet)

    def add_lora_to_text_encoder(self, text_encoder: "SDTextEncoder", state_dict_lora, alpha=1.0, device="cuda"):
        """Add the ``lora_te_`` deltas to ``text_encoder``'s weights and reload them."""
        state_dict_text_encoder = text_encoder.state_dict()
        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
        state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
        if len(state_dict_lora) > 0:
            for name in state_dict_lora:
                state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
            text_encoder.load_state_dict(state_dict_text_encoder)
|
||||||
|
|
||||||
@@ -1,20 +1,28 @@
|
|||||||
from .sd_unet import SDUNet, Attention, GEGLU
|
from .sd_unet import SDUNet, Attention, GEGLU
|
||||||
|
from .svd_unet import get_timestep_embedding
|
||||||
import torch
|
import torch
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
|
||||||
class TemporalTransformerBlock(torch.nn.Module):
|
class TemporalTransformerBlock(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
|
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32, add_positional_conv=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.add_positional_conv = add_positional_conv
|
||||||
|
|
||||||
# 1. Self-Attn
|
# 1. Self-Attn
|
||||||
self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
|
emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
|
||||||
|
self.pe1 = torch.nn.Parameter(emb)
|
||||||
|
if add_positional_conv:
|
||||||
|
self.positional_conv_1 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
|
||||||
self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
||||||
self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
||||||
|
|
||||||
# 2. Cross-Attn
|
# 2. Cross-Attn
|
||||||
self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
|
emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
|
||||||
|
self.pe2 = torch.nn.Parameter(emb)
|
||||||
|
if add_positional_conv:
|
||||||
|
self.positional_conv_2 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
|
||||||
self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
||||||
self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
||||||
|
|
||||||
@@ -24,19 +32,47 @@ class TemporalTransformerBlock(torch.nn.Module):
|
|||||||
self.ff = torch.nn.Linear(dim * 4, dim)
|
self.ff = torch.nn.Linear(dim * 4, dim)
|
||||||
|
|
||||||
|
|
||||||
|
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
    """Map a frame index onto an index into the positional-embedding table.

    Frames below ``max_id`` map to themselves. Later frames bounce back and
    forth: the index first walks down from ``max_id - 2`` for
    ``repeat_length`` steps, then walks up to ``max_id - 1`` for another
    ``repeat_length`` steps, and the cycle repeats.
    """
    if frame_id < max_id:
        return frame_id

    # Offset inside one down-then-up cycle of length 2 * repeat_length.
    offset = (frame_id - max_id) % (repeat_length * 2)
    if offset < repeat_length:
        # Descending half of the cycle.
        return max_id - 2 - offset
    # Ascending half of the cycle.
    return max_id - 2 * repeat_length + offset
|
||||||
|
|
||||||
|
|
||||||
|
def positional_ids(self, num_frames):
    """Return the embedding-table index for each of ``num_frames`` frames.

    The table length is read from ``self.pe1`` (its second dimension); frames
    beyond that length are folded back into range by
    ``frame_id_to_position_id``. The result is a 1-D ``torch.IntTensor``.
    """
    table_length = self.pe1.shape[1]
    ids = []
    for frame in range(num_frames):
        ids.append(self.frame_id_to_position_id(frame, table_length, table_length - 1))
    return torch.IntTensor(ids)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, hidden_states, batch_size=1):
|
def forward(self, hidden_states, batch_size=1):
|
||||||
|
|
||||||
# 1. Self-Attention
|
# 1. Self-Attention
|
||||||
norm_hidden_states = self.norm1(hidden_states)
|
norm_hidden_states = self.norm1(hidden_states)
|
||||||
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
||||||
attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
|
norm_hidden_states = norm_hidden_states + self.pe1[:, self.positional_ids(norm_hidden_states.shape[1])]
|
||||||
|
if self.add_positional_conv:
|
||||||
|
norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
|
||||||
|
norm_hidden_states = self.positional_conv_1(norm_hidden_states)
|
||||||
|
norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
|
||||||
|
attn_output = self.attn1(norm_hidden_states)
|
||||||
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
||||||
hidden_states = attn_output + hidden_states
|
hidden_states = attn_output + hidden_states
|
||||||
|
|
||||||
# 2. Cross-Attention
|
# 2. Cross-Attention
|
||||||
norm_hidden_states = self.norm2(hidden_states)
|
norm_hidden_states = self.norm2(hidden_states)
|
||||||
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
||||||
attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
|
norm_hidden_states = norm_hidden_states + self.pe2[:, self.positional_ids(norm_hidden_states.shape[1])]
|
||||||
|
if self.add_positional_conv:
|
||||||
|
norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
|
||||||
|
norm_hidden_states = self.positional_conv_2(norm_hidden_states)
|
||||||
|
norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
|
||||||
|
attn_output = self.attn2(norm_hidden_states)
|
||||||
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
||||||
hidden_states = attn_output + hidden_states
|
hidden_states = attn_output + hidden_states
|
||||||
|
|
||||||
@@ -51,7 +87,7 @@ class TemporalTransformerBlock(torch.nn.Module):
|
|||||||
|
|
||||||
class TemporalBlock(torch.nn.Module):
|
class TemporalBlock(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
|
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, add_positional_conv=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = num_attention_heads * attention_head_dim
|
inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
|
||||||
@@ -62,7 +98,9 @@ class TemporalBlock(torch.nn.Module):
|
|||||||
TemporalTransformerBlock(
|
TemporalTransformerBlock(
|
||||||
inner_dim,
|
inner_dim,
|
||||||
num_attention_heads,
|
num_attention_heads,
|
||||||
attention_head_dim
|
attention_head_dim,
|
||||||
|
max_position_embeddings=32 if add_positional_conv is None else add_positional_conv,
|
||||||
|
add_positional_conv=add_positional_conv
|
||||||
)
|
)
|
||||||
for d in range(num_layers)
|
for d in range(num_layers)
|
||||||
])
|
])
|
||||||
@@ -92,30 +130,30 @@ class TemporalBlock(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class SDMotionModel(torch.nn.Module):
|
class SDMotionModel(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self, add_positional_conv=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.motion_modules = torch.nn.ModuleList([
|
self.motion_modules = torch.nn.ModuleList([
|
||||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
])
|
])
|
||||||
self.call_block_id = {
|
self.call_block_id = {
|
||||||
1: 0,
|
1: 0,
|
||||||
@@ -144,8 +182,7 @@ class SDMotionModel(torch.nn.Module):
|
|||||||
def forward(self):
|
def forward(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SDMotionModelStateDictConverter()
|
return SDMotionModelStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
@@ -153,7 +190,42 @@ class SDMotionModelStateDictConverter:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
    """Map a frame index onto an index into a positional-embedding table.

    Frames below ``max_id`` map to themselves. Later frames bounce back and
    forth: the index first walks down from ``max_id - 2`` for
    ``repeat_length`` steps, then walks up to ``max_id - 1`` for another
    ``repeat_length`` steps, and the cycle repeats.
    """
    if frame_id < max_id:
        return frame_id

    # Offset inside one down-then-up cycle of length 2 * repeat_length.
    offset = (frame_id - max_id) % (repeat_length * 2)
    if offset < repeat_length:
        # Descending half of the cycle.
        return max_id - 2 - offset
    # Ascending half of the cycle.
    return max_id - 2 * repeat_length + offset
|
||||||
|
|
||||||
|
def process_positional_conv_parameters(self, state_dict, add_positional_conv):
    """Extend positional embeddings and add identity positional convolutions.

    For each of the 21 motion modules, ``pe1`` and ``pe2`` are re-gathered
    along dimension 1 to ``add_positional_conv`` positions (indices produced
    by ``frame_id_to_position_id`` with a 16-entry source table), and
    parameters for ``positional_conv_1`` / ``positional_conv_2`` are inserted:
    zero bias and a kernel that is zero except for an identity matrix at the
    centre tap.

    Args:
        state_dict: Converted motion-model state dict; modified in place.
        add_positional_conv: Number of positions in the extended embeddings.

    Returns:
        The same ``state_dict`` object.
    """
    gather_ids = [self.frame_id_to_position_id(frame, 16, 15) for frame in range(add_positional_conv)]
    conv_names = ("positional_conv_1", "positional_conv_2")

    for module_id in range(21):
        prefix = f"motion_modules.{module_id}.transformer_blocks.0."

        # Extend positional embedding
        for pe_name in ("pe1", "pe2"):
            state_dict[prefix + pe_name] = state_dict[prefix + pe_name][:, gather_ids]

        # add post convolution
        dim = state_dict[prefix + "pe1"].shape[-1]
        for conv_name in conv_names:
            state_dict[prefix + conv_name + ".bias"] = torch.zeros((dim,))
        for conv_name in conv_names:
            # Fresh tensor per convolution so the two kernels never share storage.
            kernel = torch.zeros((dim, dim, 3))
            kernel[:, :, 1] = torch.eye(dim, dim)
            state_dict[prefix + conv_name + ".weight"] = kernel

    return state_dict
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict, add_positional_conv=None):
|
||||||
rename_dict = {
|
rename_dict = {
|
||||||
"norm": "norm",
|
"norm": "norm",
|
||||||
"proj_in": "proj_in",
|
"proj_in": "proj_in",
|
||||||
@@ -193,7 +265,9 @@ class SDMotionModelStateDictConverter:
|
|||||||
else:
|
else:
|
||||||
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
|
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
|
||||||
state_dict_[rename] = state_dict[name]
|
state_dict_[rename] = state_dict[name]
|
||||||
|
if add_positional_conv is not None:
|
||||||
|
state_dict_ = self.process_positional_conv_parameters(state_dict_, add_positional_conv)
|
||||||
return state_dict_
|
return state_dict_
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
def from_civitai(self, state_dict, add_positional_conv=None):
|
||||||
return self.from_diffusers(state_dict)
|
return self.from_diffusers(state_dict, add_positional_conv=add_positional_conv)
|
||||||
|
|||||||
115
diffsynth/models/sd_motion_ex.py
Normal file
115
diffsynth/models/sd_motion_ex.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
from .attention import Attention
|
||||||
|
from .svd_unet import get_timestep_embedding
|
||||||
|
import torch
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ExVideoMotionBlock(torch.nn.Module):
    """Temporal block that mixes features across the frames of a video.

    A learnable per-frame positional embedding is added to the incoming
    feature maps, an optional 3D convolution is applied, and ``num_layers``
    pre-norm residual attention layers are run on sequences laid out as
    ``(B H W) T C``, i.e. one sequence of frames per spatial location.
    """

    def __init__(self, num_attention_heads, attention_head_dim, in_channels, max_position_embeddings=16, num_layers=1, add_positional_conv=None):
        """Build the block.

        Args:
            num_attention_heads: ``num_heads`` passed to every ``Attention``.
            attention_head_dim: ``head_dim`` passed to every ``Attention``.
            in_channels: Channel count of the feature maps (also the
                attention ``q_dim`` and LayerNorm width).
            max_position_embeddings: Number of rows in the positional
                embedding table.
            num_layers: Number of norm + attention layers.
            add_positional_conv: When not ``None``, a ``Conv3d`` with kernel
                size 3 is applied after the positional embedding is added.
        """
        super().__init__()

        # Stored as (positions, channels, 1, 1) so that it broadcasts over the
        # spatial dimensions of a (frames, channels, height, width) feature map.
        emb = get_timestep_embedding(torch.arange(max_position_embeddings), in_channels, True, 0).reshape(max_position_embeddings, in_channels, 1, 1)
        self.positional_embedding = torch.nn.Parameter(emb)
        self.positional_conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) if add_positional_conv is not None else None
        self.norms = torch.nn.ModuleList([torch.nn.LayerNorm(in_channels) for _ in range(num_layers)])
        self.attns = torch.nn.ModuleList([Attention(q_dim=in_channels, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) for _ in range(num_layers)])

    def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
        """Map a frame index onto an index into the positional-embedding table.

        Frames below ``max_id`` map to themselves. Later frames bounce back
        and forth: the index walks down from ``max_id - 2`` for
        ``repeat_length`` steps, then up to ``max_id - 1`` for another
        ``repeat_length`` steps, and the cycle repeats.
        """
        if frame_id < max_id:
            position_id = frame_id
        else:
            position_id = (frame_id - max_id) % (repeat_length * 2)
            if position_id < repeat_length:
                position_id = max_id - 2 - position_id
            else:
                position_id = max_id - 2 * repeat_length + position_id
        return position_id

    def positional_ids(self, num_frames):
        """Return a 1-D ``IntTensor`` of embedding-table indices for ``num_frames`` frames."""
        max_id = self.positional_embedding.shape[0]
        positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)])
        return positional_ids

    def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1, **kwargs):
        """Apply temporal attention to a stack of frame feature maps.

        Args:
            hidden_states: Tensor laid out as ``(B T) C H W`` where
                ``B == batch_size``.
            time_emb, text_emb, res_stack: Passed through unchanged.
            batch_size: Number of videos packed into the leading dimension.

        Returns:
            ``(hidden_states, time_emb, text_emb, res_stack)`` with
            ``hidden_states`` in the same ``(B T) C H W`` layout as the input.
        """
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        pos_emb = self.positional_ids(batch // batch_size)
        pos_emb = self.positional_embedding[pos_emb]
        # Tensor.repeat needs one repeat count per tensor dimension; the
        # embedding is 4-D here, so a single argument raises. Tile only along
        # the leading (frame) axis so the result lines up with the (B T) layout.
        pos_emb = pos_emb.repeat(batch_size, 1, 1, 1)
        hidden_states = hidden_states + pos_emb
        if self.positional_conv is not None:
            hidden_states = rearrange(hidden_states, "(B T) C H W -> B C T H W", B=batch_size)
            hidden_states = self.positional_conv(hidden_states)
            hidden_states = rearrange(hidden_states, "B C T H W -> (B H W) T C")
        else:
            hidden_states = rearrange(hidden_states, "(B T) C H W -> (B H W) T C", B=batch_size)

        # Pre-norm residual attention layers.
        for norm, attn in zip(self.norms, self.attns):
            norm_hidden_states = norm(hidden_states)
            attn_output = attn(norm_hidden_states)
            hidden_states = hidden_states + attn_output

        hidden_states = rearrange(hidden_states, "(B H W) T C -> (B T) C H W", B=batch_size, H=height, W=width)
        hidden_states = hidden_states + residual
        return hidden_states, time_emb, text_emb, res_stack
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ExVideoMotionModel(torch.nn.Module):
    """Container for the 21 ``ExVideoMotionBlock`` modules of the motion model.

    ``call_block_id`` maps an integer key to the index of the motion module
    in ``motion_modules`` (presumably the key is the position of a UNet block
    after which the module is called; confirm against the caller).
    """

    def __init__(self, num_layers=2):
        super().__init__()

        # (attention_head_dim, in_channels, number of consecutive blocks);
        # every block uses 8 attention heads.
        stages = [
            (40, 320, 2),
            (80, 640, 2),
            (160, 1280, 11),
            (80, 640, 3),
            (40, 320, 3),
        ]
        self.motion_modules = torch.nn.ModuleList([
            ExVideoMotionBlock(8, head_dim, channels, num_layers=num_layers)
            for head_dim, channels, count in stages
            for _ in range(count)
        ])

        # Keys in ascending order; the value is the index into motion_modules.
        call_positions = [1, 4, 9, 12, 17, 20, 24, 26, 29, 32, 34, 36, 40, 43, 46, 50, 53, 56, 60, 63, 66]
        self.call_block_id = {position: index for index, position in enumerate(call_positions)}

    def forward(self):
        """No-op; the motion modules are invoked individually by the caller."""
        return None

    def state_dict_converter(self):
        """Placeholder; no converter is provided for this model."""
        return None
|
||||||
@@ -71,8 +71,7 @@ class SDTextEncoder(torch.nn.Module):
|
|||||||
embeds = self.final_layer_norm(embeds)
|
embeds = self.final_layer_norm(embeds)
|
||||||
return embeds
|
return embeds
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SDTextEncoderStateDictConverter()
|
return SDTextEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -323,7 +323,7 @@ class SDUNet(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
|
def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
|
||||||
# 1. time
|
# 1. time
|
||||||
time_emb = self.time_proj(timestep).to(sample.dtype)
|
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||||
time_emb = self.time_embedding(time_emb)
|
time_emb = self.time_embedding(time_emb)
|
||||||
|
|
||||||
# 2. pre-process
|
# 2. pre-process
|
||||||
@@ -342,8 +342,7 @@ class SDUNet(torch.nn.Module):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SDUNetStateDictConverter()
|
return SDUNetStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -90,8 +90,6 @@ class SDVAEDecoder(torch.nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||||
original_dtype = sample.dtype
|
|
||||||
sample = sample.to(dtype=next(iter(self.parameters())).dtype)
|
|
||||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||||
if tiled:
|
if tiled:
|
||||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
@@ -112,12 +110,10 @@ class SDVAEDecoder(torch.nn.Module):
|
|||||||
hidden_states = self.conv_norm_out(hidden_states)
|
hidden_states = self.conv_norm_out(hidden_states)
|
||||||
hidden_states = self.conv_act(hidden_states)
|
hidden_states = self.conv_act(hidden_states)
|
||||||
hidden_states = self.conv_out(hidden_states)
|
hidden_states = self.conv_out(hidden_states)
|
||||||
hidden_states = hidden_states.to(original_dtype)
|
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SDVAEDecoderStateDictConverter()
|
return SDVAEDecoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -50,8 +50,6 @@ class SDVAEEncoder(torch.nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||||
original_dtype = sample.dtype
|
|
||||||
sample = sample.to(dtype=next(iter(self.parameters())).dtype)
|
|
||||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||||
if tiled:
|
if tiled:
|
||||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
@@ -73,7 +71,6 @@ class SDVAEEncoder(torch.nn.Module):
|
|||||||
hidden_states = self.quant_conv(hidden_states)
|
hidden_states = self.quant_conv(hidden_states)
|
||||||
hidden_states = hidden_states[:, :4]
|
hidden_states = hidden_states[:, :4]
|
||||||
hidden_states *= self.scaling_factor
|
hidden_states *= self.scaling_factor
|
||||||
hidden_states = hidden_states.to(original_dtype)
|
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@@ -94,8 +91,7 @@ class SDVAEEncoder(torch.nn.Module):
|
|||||||
hidden_states = torch.concat(hidden_states, dim=2)
|
hidden_states = torch.concat(hidden_states, dim=2)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SDVAEEncoderStateDictConverter()
|
return SDVAEEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,318 +0,0 @@
|
|||||||
import torch
|
|
||||||
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
|
|
||||||
from .sdxl_unet import SDXLUNet
|
|
||||||
from .tiler import TileWorker
|
|
||||||
from .sd_controlnet import ControlNetConditioningLayer
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class QuickGELU(torch.nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualAttentionBlock(torch.nn.Module):
    """Pre-norm transformer block: self-attention followed by a 4x-wide MLP.

    Both sub-layers are wrapped in residual connections. The MLP uses
    ``QuickGELU`` as its activation.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = torch.nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = torch.nn.LayerNorm(d_model)
        hidden_dim = d_model * 4
        mlp_layers = OrderedDict()
        mlp_layers["c_fc"] = torch.nn.Linear(d_model, hidden_dim)
        mlp_layers["gelu"] = QuickGELU()
        mlp_layers["c_proj"] = torch.nn.Linear(hidden_dim, d_model)
        self.mlp = torch.nn.Sequential(mlp_layers)
        self.ln_2 = torch.nn.LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Run self-attention on ``x``, using the stored mask when present."""
        mask = self.attn_mask
        if mask is not None:
            # Keep the stored mask on the dtype/device of the most recent input.
            mask = mask.to(dtype=x.dtype, device=x.device)
        self.attn_mask = mask
        attn_output, _ = self.attn(x, x, x, need_weights=False, attn_mask=mask)
        return attn_output

    def forward(self, x: torch.Tensor):
        after_attention = x + self.attention(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SDXLControlNetUnion(torch.nn.Module):
    """ControlNet for SDXL that serves several control types with one network.

    The control type is selected per call through ``processor_id`` (see
    ``self.task_id``). ``forward`` returns one residual tensor per entry of
    ``self.controlnet_blocks``.
    """

    def __init__(self, global_pool=False):
        """Build the network.

        Args:
            global_pool: When true, ``forward`` averages every returned
                residual over its two spatial dimensions (``keepdim=True``).
        """
        super().__init__()
        # Timestep embedding.
        self.time_proj = Timesteps(320)
        self.time_embedding = torch.nn.Sequential(
            torch.nn.Linear(320, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280)
        )
        # Embedding of add_text_embeds concatenated with the projected add_time_id.
        self.add_time_proj = Timesteps(256)
        self.add_time_embedding = torch.nn.Sequential(
            torch.nn.Linear(2816, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280)
        )
        # Embedding of the 8-way one-hot control-type vector.
        self.control_type_proj = Timesteps(256)
        self.control_type_embedding = torch.nn.Sequential(
            torch.nn.Linear(256 * 8, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280)
        )
        self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)

        # Modules used by fuse_condition_to_input.
        self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
        self.controlnet_transformer = ResidualAttentionBlock(320, 8)
        self.task_embedding = torch.nn.Parameter(torch.randn(8, 320))
        self.spatial_ch_projs = torch.nn.Linear(320, 320)

        self.blocks = torch.nn.ModuleList([
            # DownBlock2D
            ResnetBlock(320, 320, 1280),
            PushBlock(),
            ResnetBlock(320, 320, 1280),
            PushBlock(),
            DownSampler(320),
            PushBlock(),
            # CrossAttnDownBlock2D
            ResnetBlock(320, 640, 1280),
            AttentionBlock(10, 64, 640, 2, 2048),
            PushBlock(),
            ResnetBlock(640, 640, 1280),
            AttentionBlock(10, 64, 640, 2, 2048),
            PushBlock(),
            DownSampler(640),
            PushBlock(),
            # CrossAttnDownBlock2D
            ResnetBlock(640, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            PushBlock(),
            ResnetBlock(1280, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            PushBlock(),
            # UNetMidBlock2DCrossAttn
            ResnetBlock(1280, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            ResnetBlock(1280, 1280, 1280),
            PushBlock()
        ])

        # 1x1 convolutions applied to the entries of res_stack in forward
        # (paired with zip, so the shorter of the two sequences decides).
        self.controlnet_blocks = torch.nn.ModuleList([
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
        ])

        self.global_pool = global_pool

        # 0 -- openpose
        # 1 -- depth
        # 2 -- hed/pidi/scribble/ted
        # 3 -- canny/lineart/anime_lineart/mlsd
        # 4 -- normal
        # 5 -- segment
        # 6 -- tile
        # 7 -- repaint
        # NOTE: ids 4 (normal) and 5 (segment) have no entry below, so forward
        # raises KeyError for any processor_id that is not one of these keys.
        self.task_id = {
            "openpose": 0,
            "depth": 1,
            "softedge": 2,
            "canny": 3,
            "lineart": 3,
            "lineart_anime": 3,
            "tile": 6,
            "inpaint": 7
        }

    def fuse_condition_to_input(self, hidden_states, task_id, conditioning):
        """Add the encoded control image to ``hidden_states``.

        The conditioning is encoded by ``controlnet_conv_in``; its spatial
        mean plus the task embedding and the spatial mean of ``hidden_states``
        are stacked and passed through ``controlnet_transformer``. The first
        output token is projected and added to the encoded conditioning as a
        per-channel offset before the sum with ``hidden_states``.
        """
        controlnet_cond = self.controlnet_conv_in(conditioning)
        # Average over the two spatial dimensions.
        feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
        feat_seq = feat_seq + self.task_embedding[task_id]
        x = torch.stack([feat_seq, torch.mean(hidden_states, dim=(2, 3))], dim=1)
        x = self.controlnet_transformer(x)

        # Unsqueeze twice so the offset broadcasts over the spatial dimensions.
        alpha = self.spatial_ch_projs(x[:,0]).unsqueeze(-1).unsqueeze(-1)
        controlnet_cond_fuser = controlnet_cond + alpha

        hidden_states = hidden_states + controlnet_cond_fuser
        return hidden_states

    def forward(
        self,
        sample, timestep, encoder_hidden_states,
        conditioning, processor_id, add_time_id, add_text_embeds,
        tiled=False, tile_size=64, tile_stride=32,
        unet:SDXLUNet=None,
        **kwargs
    ):
        """Compute the ControlNet residuals for one denoising step.

        Args:
            sample: Input to ``conv_in`` (4 channels).
            timestep: Value embedded by ``time_proj``.
            encoder_hidden_states: Text embedding passed to the blocks.
            conditioning: Control image fed to ``controlnet_conv_in``.
            processor_id: Key of ``self.task_id`` selecting the control type.
            add_time_id: Extra values embedded by ``add_time_proj``.
            add_text_embeds: Embedding concatenated with the projected
                ``add_time_id``.
            tiled: Run every non-``PushBlock`` block through ``TileWorker``.
            tile_size, tile_stride: Tile geometry at input resolution; scaled
                by the current feature-map height relative to the input.
            unet: Optional SDXL UNet. When given and ``unet.is_kolors`` is
                true, its ``add_time_embedding`` and ``text_intermediate_proj``
                are used instead of this module's own embedding / raw text.

        Returns:
            List of tensors, one per ``controlnet_blocks`` entry paired with
            ``res_stack``; spatially averaged when ``global_pool`` is set.

        Raises:
            KeyError: If ``processor_id`` is not a key of ``self.task_id``.
        """
        task_id = self.task_id[processor_id]

        # 1. time
        t_emb = self.time_proj(timestep).to(sample.dtype)
        t_emb = self.time_embedding(t_emb)

        time_embeds = self.add_time_proj(add_time_id)
        time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
        add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(sample.dtype)
        if unet is not None and unet.is_kolors:
            add_embeds = unet.add_time_embedding(add_embeds)
        else:
            add_embeds = self.add_time_embedding(add_embeds)

        # One-hot control-type vector per sample, embedded like a timestep.
        control_type = torch.zeros((sample.shape[0], 8), dtype=sample.dtype, device=sample.device)
        control_type[:, task_id] = 1
        control_embeds = self.control_type_proj(control_type.flatten())
        control_embeds = control_embeds.reshape((sample.shape[0], -1))
        control_embeds = control_embeds.to(sample.dtype)
        control_embeds = self.control_type_embedding(control_embeds)
        time_emb = t_emb + add_embeds + control_embeds

        # 2. pre-process
        height, width = sample.shape[2], sample.shape[3]
        hidden_states = self.conv_in(sample)
        hidden_states = self.fuse_condition_to_input(hidden_states, task_id, conditioning)
        text_emb = encoder_hidden_states
        if unet is not None and unet.is_kolors:
            text_emb = unet.text_intermediate_proj(text_emb)
        res_stack = [hidden_states]

        # 3. blocks
        for i, block in enumerate(self.blocks):
            if tiled and not isinstance(block, PushBlock):
                _, _, inter_height, _ = hidden_states.shape
                # Tile size/stride are given at input resolution; rescale them
                # to the resolution of the current feature map.
                resize_scale = inter_height / height
                hidden_states = TileWorker().tiled_forward(
                    lambda x: block(x, time_emb, text_emb, res_stack)[0],
                    hidden_states,
                    int(tile_size * resize_scale),
                    int(tile_stride * resize_scale),
                    tile_device=hidden_states.device,
                    tile_dtype=hidden_states.dtype
                )
            else:
                hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)

        # 4. ControlNet blocks
        controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]

        # pool
        if self.global_pool:
            controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]

        return controlnet_res_stack

    @staticmethod
    def state_dict_converter():
        """Return the converter that maps external checkpoints to this model's keys."""
        return SDXLControlNetUnionStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SDXLControlNetUnionStateDictConverter:
    """Rename the parameters of a ControlNet-Union checkpoint.

    ``from_diffusers`` maps each source parameter name onto a new name and
    returns a dict keyed by the new names; ``from_civitai`` delegates to it.
    """

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Return a copy of ``state_dict`` with renamed keys.

        Names are processed in sorted order because the running block
        counters below depend on the order in which source blocks are seen.
        Names that match no rule are printed together with their shape and
        kept unchanged. Parameters whose name contains ``.proj_in.`` or
        ``.proj_out.`` are passed through ``squeeze()``.
        """
        # architecture: the position in this list is the index used in the
        # converted "blocks.<index>." names.
        block_types = [
            "ResnetBlock", "PushBlock", "ResnetBlock", "PushBlock", "DownSampler", "PushBlock",
            "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock", "DownSampler", "PushBlock",
            "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock",
            "ResnetBlock", "AttentionBlock", "ResnetBlock", "PushBlock"
        ]

        # Full-name renames that are applied verbatim.
        controlnet_rename_dict = {
            "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
            "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
            "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
            "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
            "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
            "controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
            "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
            "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
            "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
            "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
            "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
            "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
            "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
            "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
            "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
            "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
            "control_add_embedding.linear_1.weight": "control_type_embedding.0.weight",
            "control_add_embedding.linear_1.bias": "control_type_embedding.0.bias",
            "control_add_embedding.linear_2.weight": "control_type_embedding.2.weight",
            "control_add_embedding.linear_2.bias": "control_type_embedding.2.bias",
        }

        # Small lookup tables used by the per-name rules below.
        untouched_prefixes = ["conv_in", "conv_norm_out", "conv_out", "task_embedding", "spatial_ch_projs"]
        linear_to_index = {"linear_1": "0", "linear_2": "2"}
        group_to_block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}
        ff_to_component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}

        # Rename each parameter
        rename_dict = {}
        block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
        last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
        for name in sorted(state_dict):
            parts = name.split(".")
            head = parts[0]
            if head in untouched_prefixes:
                pass
            elif name in controlnet_rename_dict:
                parts = controlnet_rename_dict[name].split(".")
            elif head == "controlnet_down_blocks":
                parts[0] = "controlnet_blocks"
            elif head == "controlnet_mid_block":
                parts = ["controlnet_blocks", "9", parts[-1]]
            elif head in ["time_embedding", "add_embedding"]:
                if head == "add_embedding":
                    parts[0] = "add_time_embedding"
                parts[1] = linear_to_index[parts[1]]
            elif head == "control_add_embedding":
                parts[0] = "control_type_embedding"
            elif head == "transformer_layes":
                parts[0] = "controlnet_transformer"
                del parts[1]
            elif head in ["down_blocks", "mid_block", "up_blocks"]:
                if head == "mid_block":
                    # Give mid_block the same "<stage>.<index>.<group>.<n>" layout as the others.
                    parts.insert(1, "0")
                block_type = group_to_block_type[parts[2]]
                source_block = ".".join(parts[:4])
                # A new source block advances the counter of its type ...
                if source_block != last_block_type_with_id[block_type]:
                    block_id[block_type] += 1
                last_block_type_with_id[block_type] = source_block
                # ... which then skips forward to the next slot of that type.
                while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
                    block_id[block_type] += 1
                parts = ["blocks", str(block_id[block_type])] + parts[4:]
                if "ff" in parts:
                    ff_index = parts.index("ff")
                    component = ff_to_component[".".join(parts[ff_index:ff_index+3])]
                    parts[ff_index:ff_index+3] = [component]
                if "to_out" in parts:
                    # Drop the sub-index that follows "to_out".
                    del parts[parts.index("to_out") + 1]
            else:
                print(name, state_dict[name].shape)
                # raise ValueError(f"Unknown parameters: {name}")
            rename_dict[name] = ".".join(parts)

        # Convert state_dict
        converted = {}
        for name, param in state_dict.items():
            if name not in rename_dict:
                continue
            if ".proj_in." in name or ".proj_out." in name:
                param = param.squeeze()
            converted[rename_dict[name]] = param
        return converted

    def from_civitai(self, state_dict):
        """Convert ``state_dict`` with the same rules as ``from_diffusers``."""
        return self.from_diffusers(state_dict)
|
|
||||||
@@ -96,8 +96,7 @@ class SDXLIpAdapter(torch.nn.Module):
|
|||||||
}
|
}
|
||||||
return ip_kv_dict
|
return ip_kv_dict
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SDXLIpAdapterStateDictConverter()
|
return SDXLIpAdapterStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -49,8 +49,7 @@ class SDXLMotionModel(torch.nn.Module):
|
|||||||
def forward(self):
|
def forward(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SDMotionModelStateDictConverter()
|
return SDMotionModelStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -36,8 +36,7 @@ class SDXLTextEncoder(torch.nn.Module):
|
|||||||
break
|
break
|
||||||
return embeds
|
return embeds
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SDXLTextEncoderStateDictConverter()
|
return SDXLTextEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
@@ -81,8 +80,7 @@ class SDXLTextEncoder2(torch.nn.Module):
|
|||||||
pooled_embeds = self.text_projection(pooled_embeds)
|
pooled_embeds = self.text_projection(pooled_embeds)
|
||||||
return pooled_embeds, hidden_states
|
return pooled_embeds, hidden_states
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SDXLTextEncoder2StateDictConverter()
|
return SDXLTextEncoder2StateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock
|
|||||||
|
|
||||||
|
|
||||||
class SDXLUNet(torch.nn.Module):
|
class SDXLUNet(torch.nn.Module):
|
||||||
def __init__(self, is_kolors=False):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.time_proj = Timesteps(320)
|
self.time_proj = Timesteps(320)
|
||||||
self.time_embedding = torch.nn.Sequential(
|
self.time_embedding = torch.nn.Sequential(
|
||||||
@@ -13,12 +13,11 @@ class SDXLUNet(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
self.add_time_proj = Timesteps(256)
|
self.add_time_proj = Timesteps(256)
|
||||||
self.add_time_embedding = torch.nn.Sequential(
|
self.add_time_embedding = torch.nn.Sequential(
|
||||||
torch.nn.Linear(5632 if is_kolors else 2816, 1280),
|
torch.nn.Linear(2816, 1280),
|
||||||
torch.nn.SiLU(),
|
torch.nn.SiLU(),
|
||||||
torch.nn.Linear(1280, 1280)
|
torch.nn.Linear(1280, 1280)
|
||||||
)
|
)
|
||||||
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
|
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
|
||||||
self.text_intermediate_proj = torch.nn.Linear(4096, 2048) if is_kolors else None
|
|
||||||
|
|
||||||
self.blocks = torch.nn.ModuleList([
|
self.blocks = torch.nn.ModuleList([
|
||||||
# DownBlock2D
|
# DownBlock2D
|
||||||
@@ -83,17 +82,13 @@ class SDXLUNet(torch.nn.Module):
|
|||||||
self.conv_act = torch.nn.SiLU()
|
self.conv_act = torch.nn.SiLU()
|
||||||
self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1)
|
self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1)
|
||||||
|
|
||||||
self.is_kolors = is_kolors
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds,
|
sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds,
|
||||||
tiled=False, tile_size=64, tile_stride=8,
|
tiled=False, tile_size=64, tile_stride=8, **kwargs
|
||||||
use_gradient_checkpointing=False,
|
|
||||||
**kwargs
|
|
||||||
):
|
):
|
||||||
# 1. time
|
# 1. time
|
||||||
t_emb = self.time_proj(timestep).to(sample.dtype)
|
t_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||||
t_emb = self.time_embedding(t_emb)
|
t_emb = self.time_embedding(t_emb)
|
||||||
|
|
||||||
time_embeds = self.add_time_proj(add_time_id)
|
time_embeds = self.add_time_proj(add_time_id)
|
||||||
@@ -107,26 +102,15 @@ class SDXLUNet(torch.nn.Module):
|
|||||||
# 2. pre-process
|
# 2. pre-process
|
||||||
height, width = sample.shape[2], sample.shape[3]
|
height, width = sample.shape[2], sample.shape[3]
|
||||||
hidden_states = self.conv_in(sample)
|
hidden_states = self.conv_in(sample)
|
||||||
text_emb = encoder_hidden_states if self.text_intermediate_proj is None else self.text_intermediate_proj(encoder_hidden_states)
|
text_emb = encoder_hidden_states
|
||||||
res_stack = [hidden_states]
|
res_stack = [hidden_states]
|
||||||
|
|
||||||
# 3. blocks
|
# 3. blocks
|
||||||
def create_custom_forward(module):
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
return module(*inputs)
|
|
||||||
return custom_forward
|
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock)):
|
hidden_states, time_emb, text_emb, res_stack = block(
|
||||||
hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
|
hidden_states, time_emb, text_emb, res_stack,
|
||||||
create_custom_forward(block),
|
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||||
hidden_states, time_emb, text_emb, res_stack,
|
)
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states, time_emb, text_emb, res_stack = block(
|
|
||||||
hidden_states, time_emb, text_emb, res_stack,
|
|
||||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. output
|
# 4. output
|
||||||
hidden_states = self.conv_norm_out(hidden_states)
|
hidden_states = self.conv_norm_out(hidden_states)
|
||||||
@@ -135,8 +119,7 @@ class SDXLUNet(torch.nn.Module):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SDXLUNetStateDictConverter()
|
return SDXLUNetStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
@@ -165,8 +148,6 @@ class SDXLUNetStateDictConverter:
|
|||||||
names = name.split(".")
|
names = name.split(".")
|
||||||
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
|
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
|
||||||
pass
|
pass
|
||||||
elif names[0] in ["encoder_hid_proj"]:
|
|
||||||
names[0] = "text_intermediate_proj"
|
|
||||||
elif names[0] in ["time_embedding", "add_embedding"]:
|
elif names[0] in ["time_embedding", "add_embedding"]:
|
||||||
if names[0] == "add_embedding":
|
if names[0] == "add_embedding":
|
||||||
names[0] = "add_time_embedding"
|
names[0] = "add_time_embedding"
|
||||||
@@ -200,10 +181,7 @@ class SDXLUNetStateDictConverter:
|
|||||||
if ".proj_in." in name or ".proj_out." in name:
|
if ".proj_in." in name or ".proj_out." in name:
|
||||||
param = param.squeeze()
|
param = param.squeeze()
|
||||||
state_dict_[rename_dict[name]] = param
|
state_dict_[rename_dict[name]] = param
|
||||||
if "text_intermediate_proj.weight" in state_dict_:
|
return state_dict_
|
||||||
return state_dict_, {"is_kolors": True}
|
|
||||||
else:
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
def from_civitai(self, state_dict):
|
||||||
rename_dict = {
|
rename_dict = {
|
||||||
@@ -1895,7 +1873,4 @@ class SDXLUNetStateDictConverter:
|
|||||||
if ".proj_in." in name or ".proj_out." in name:
|
if ".proj_in." in name or ".proj_out." in name:
|
||||||
param = param.squeeze()
|
param = param.squeeze()
|
||||||
state_dict_[rename_dict[name]] = param
|
state_dict_[rename_dict[name]] = param
|
||||||
if "text_intermediate_proj.weight" in state_dict_:
|
return state_dict_
|
||||||
return state_dict_, {"is_kolors": True}
|
|
||||||
else:
|
|
||||||
return state_dict_
|
|
||||||
|
|||||||
@@ -2,23 +2,14 @@ from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter
|
|||||||
|
|
||||||
|
|
||||||
class SDXLVAEDecoder(SDVAEDecoder):
|
class SDXLVAEDecoder(SDVAEDecoder):
|
||||||
def __init__(self, upcast_to_float32=True):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scaling_factor = 0.13025
|
self.scaling_factor = 0.13025
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SDXLVAEDecoderStateDictConverter()
|
return SDXLVAEDecoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
|
class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
state_dict = super().from_diffusers(state_dict)
|
|
||||||
return state_dict, {"upcast_to_float32": True}
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
state_dict = super().from_civitai(state_dict)
|
|
||||||
return state_dict, {"upcast_to_float32": True}
|
|
||||||
|
|||||||
@@ -2,23 +2,14 @@ from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
|
|||||||
|
|
||||||
|
|
||||||
class SDXLVAEEncoder(SDVAEEncoder):
|
class SDXLVAEEncoder(SDVAEEncoder):
|
||||||
def __init__(self, upcast_to_float32=True):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scaling_factor = 0.13025
|
self.scaling_factor = 0.13025
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SDXLVAEEncoderStateDictConverter()
|
return SDXLVAEEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
|
class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
state_dict = super().from_diffusers(state_dict)
|
|
||||||
return state_dict, {"upcast_to_float32": True}
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
state_dict = super().from_civitai(state_dict)
|
|
||||||
return state_dict, {"upcast_to_float32": True}
|
|
||||||
|
|||||||
@@ -44,8 +44,7 @@ class SVDImageEncoder(torch.nn.Module):
|
|||||||
embeds = self.visual_projection(embeds)
|
embeds = self.visual_projection(embeds)
|
||||||
return embeds
|
return embeds
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SVDImageEncoderStateDictConverter()
|
return SVDImageEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ def get_timestep_embedding(
|
|||||||
downscale_freq_shift: float = 1,
|
downscale_freq_shift: float = 1,
|
||||||
scale: float = 1,
|
scale: float = 1,
|
||||||
max_period: int = 10000,
|
max_period: int = 10000,
|
||||||
computation_device = None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||||
@@ -58,11 +57,11 @@ def get_timestep_embedding(
|
|||||||
|
|
||||||
half_dim = embedding_dim // 2
|
half_dim = embedding_dim // 2
|
||||||
exponent = -math.log(max_period) * torch.arange(
|
exponent = -math.log(max_period) * torch.arange(
|
||||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device
|
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||||
)
|
)
|
||||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||||
|
|
||||||
emb = torch.exp(exponent).to(timesteps.device)
|
emb = torch.exp(exponent)
|
||||||
emb = timesteps[:, None].float() * emb[None, :]
|
emb = timesteps[:, None].float() * emb[None, :]
|
||||||
|
|
||||||
# scale embeddings
|
# scale embeddings
|
||||||
@@ -82,12 +81,11 @@ def get_timestep_embedding(
|
|||||||
|
|
||||||
|
|
||||||
class TemporalTimesteps(torch.nn.Module):
|
class TemporalTimesteps(torch.nn.Module):
|
||||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None):
|
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_channels = num_channels
|
self.num_channels = num_channels
|
||||||
self.flip_sin_to_cos = flip_sin_to_cos
|
self.flip_sin_to_cos = flip_sin_to_cos
|
||||||
self.downscale_freq_shift = downscale_freq_shift
|
self.downscale_freq_shift = downscale_freq_shift
|
||||||
self.computation_device = computation_device
|
|
||||||
|
|
||||||
def forward(self, timesteps):
|
def forward(self, timesteps):
|
||||||
t_emb = get_timestep_embedding(
|
t_emb = get_timestep_embedding(
|
||||||
@@ -95,7 +93,6 @@ class TemporalTimesteps(torch.nn.Module):
|
|||||||
self.num_channels,
|
self.num_channels,
|
||||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||||
downscale_freq_shift=self.downscale_freq_shift,
|
downscale_freq_shift=self.downscale_freq_shift,
|
||||||
computation_device=self.computation_device,
|
|
||||||
)
|
)
|
||||||
return t_emb
|
return t_emb
|
||||||
|
|
||||||
@@ -410,8 +407,7 @@ class SVDUNet(torch.nn.Module):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SVDUNetStateDictConverter()
|
return SVDUNetStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -199,8 +199,7 @@ class SVDVAEDecoder(torch.nn.Module):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SVDVAEDecoderStateDictConverter()
|
return SVDVAEDecoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,7 @@ class SVDVAEEncoder(SDVAEEncoder):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.scaling_factor = 0.13025
|
self.scaling_factor = 0.13025
|
||||||
|
|
||||||
@staticmethod
|
def state_dict_converter(self):
|
||||||
def state_dict_converter():
|
|
||||||
return SVDVAEEncoderStateDictConverter()
|
return SVDVAEEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -103,132 +103,4 @@ class TileWorker:
|
|||||||
|
|
||||||
# Done!
|
# Done!
|
||||||
model_output = model_output.to(device=inference_device, dtype=inference_dtype)
|
model_output = model_output.to(device=inference_device, dtype=inference_dtype)
|
||||||
return model_output
|
return model_output
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FastTileWorker:
    """Evaluate a function tile by tile over a 4D tensor and blend the results."""

    def __init__(self):
        pass

    def build_mask(self, data, is_bound):
        """Build the blending weights for one tile.

        Args:
            data: Tile output; only its last two dimensions (H, W), dtype and
                device are read.
            is_bound: Four flags ``(top, bottom, left, right)``. A set flag
                means that edge of the tile lies on the border of the full
                input, so the weights are not faded towards that edge.

        Returns:
            Tensor of shape ``(1, H, W)`` with values in ``(0, 1]`` that grow
            with the distance to the nearest non-border tile edge.
        """
        _, _, H, W = data.shape
        h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
        w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
        border_width = (H + W) // 4
        pad = torch.ones_like(h) * border_width
        distances = [
            pad if is_bound[0] else h + 1,
            pad if is_bound[1] else H - h,
            pad if is_bound[2] else w + 1,
            pad if is_bound[3] else W - w
        ]
        mask = torch.stack(distances).min(dim=0).values
        mask = mask.clip(1, border_width)
        mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
        mask = rearrange(mask, "H W -> 1 H W")
        return mask

    def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
        """Call ``forward_fn`` on every tile and return the blended result.

        Args:
            forward_fn: Called as ``forward_fn(hl, hr, wl, wr)`` with the tile
                bounds. Its result is accumulated into the region
                ``[hl:hr, wl:wr]`` of the output, so it is presumably expected
                to have the tile's spatial size (not checked here).
            model_input: 4D tensor; only its shape ``(B, C, H, W)`` is read.
            tile_size: Tile edge length.
            tile_stride: Step between tile origins.
            tile_device: Device of the accumulation buffers.
            tile_dtype: Dtype of the accumulation buffers.
            border_width: Accepted for interface compatibility; the blending
                mask derives its own border width from the tile size.

        Returns:
            Tensor of shape ``(B, C, H, W)`` on ``tile_device``.
        """
        # Prepare
        B, C, H, W = model_input.shape
        weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device)
        values = torch.zeros((B, C, H, W), dtype=tile_dtype, device=tile_device)

        # Split tasks
        tasks = []
        for h in range(0, H, tile_stride):
            for w in range(0, W, tile_stride):
                # Skip tiles whose predecessor already reached the far edge.
                if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
                    continue
                h_, w_ = h + tile_size, w + tile_size
                # Shift the last tile back inside the input. The start is
                # clamped at 0 so an input smaller than one tile does not
                # produce a negative start (which would wrap around when used
                # for slicing and leave part of the output uncovered).
                if h_ > H:
                    h, h_ = max(H - tile_size, 0), H
                if w_ > W:
                    w, w_ = max(W - tile_size, 0), W
                tasks.append((h, h_, w, w_))

        # Run
        for hl, hr, wl, wr in tasks:
            # Forward
            hidden_states_batch = forward_fn(hl, hr, wl, wr).to(dtype=tile_dtype, device=tile_device)

            mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
            values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
            weight[:, :, hl:hr, wl:wr] += mask
        # Normalize by the accumulated weights of the overlapping tiles.
        values /= weight
        return values
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TileWorker2Dto3D:
    """
    Tile a 5D tensor over its two spatial dimensions only.

    The temporal dimension is always passed to the forward function in full.
    """
    def __init__(self):
        pass

    def build_mask(self, T, H, W, dtype, device, is_bound, border_width):
        """Return blending weights of shape ``(1, 1, T, H, W)``.

        ``is_bound`` holds six flags in the order (t start, t end, h start,
        h end, w start, w end); a set flag disables fading towards that edge.
        ``border_width`` falls back to ``(H + W) // 4`` when it is None.
        """
        t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
        h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
        w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
        if border_width is None:
            border_width = (H + W) // 4
        pad = torch.ones_like(h) * border_width
        edge_distances = [
            pad if is_bound[0] else t + 1,
            pad if is_bound[1] else T - t,
            pad if is_bound[2] else h + 1,
            pad if is_bound[3] else H - h,
            pad if is_bound[4] else w + 1,
            pad if is_bound[5] else W - w
        ]
        mask = torch.stack(edge_distances).min(dim=0).values
        mask = mask.clip(1, border_width)
        mask = (mask / border_width).to(dtype=dtype, device=device)
        return rearrange(mask, "T H W -> 1 1 T H W")

    def tiled_forward(
        self,
        forward_fn,
        model_input,
        tile_size, tile_stride,
        tile_device="cpu", tile_dtype=torch.float32,
        computation_device="cuda", computation_dtype=torch.float32,
        border_width=None, scales=[1, 1, 1, 1],
        progress_bar=lambda x:x
    ):
        """Apply ``forward_fn`` to spatial tiles of ``model_input`` and blend them.

        ``tile_size`` and ``tile_stride`` are ``(H, W)`` pairs. ``scales``
        gives the output/input size ratio for (C, T, H, W) and is used to
        size the output buffers and to place each tile's output. Each tile
        is moved to ``computation_device``/``computation_dtype`` before the
        call and the result is accumulated on ``tile_device``/``tile_dtype``.
        """
        B, C, T, H, W = model_input.shape
        scale_C, scale_T, scale_H, scale_W = scales
        tile_size_H, tile_size_W = tile_size
        tile_stride_H, tile_stride_W = tile_stride

        out_T, out_H, out_W = int(T*scale_T), int(H*scale_H), int(W*scale_W)
        value = torch.zeros((B, int(C*scale_C), out_T, out_H, out_W), dtype=tile_dtype, device=tile_device)
        weight = torch.zeros((1, 1, out_T, out_H, out_W), dtype=tile_dtype, device=tile_device)

        # Split tasks
        tasks = []
        for h in range(0, H, tile_stride_H):
            for w in range(0, W, tile_stride_W):
                # The previous tile along an axis already reached the far edge.
                row_done = h-tile_stride_H >= 0 and h-tile_stride_H+tile_size_H >= H
                col_done = w-tile_stride_W >= 0 and w-tile_stride_W+tile_size_W >= W
                if row_done or col_done:
                    continue
                h_, w_ = h + tile_size_H, w + tile_size_W
                # Shift an overhanging tile back inside the input.
                if h_ > H:
                    h, h_ = max(H - tile_size_H, 0), H
                if w_ > W:
                    w, w_ = max(W - tile_size_W, 0), W
                tasks.append((h, h_, w, w_))

        # Run
        for hl, hr, wl, wr in progress_bar(tasks):
            mask = self.build_mask(
                out_T, int((hr-hl)*scale_H), int((wr-wl)*scale_W),
                tile_dtype, tile_device,
                is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W),
                border_width=border_width
            )
            grid_input = model_input[:, :, :, hl:hr, wl:wr].to(dtype=computation_dtype, device=computation_device)
            grid_output = forward_fn(grid_input).to(dtype=tile_dtype, device=tile_device)
            rows = slice(int(hl*scale_H), int(hr*scale_H))
            cols = slice(int(wl*scale_W), int(wr*scale_W))
            value[:, :, :, rows, cols] += grid_output * mask
            weight[:, :, :, rows, cols] += mask
        # Normalize by the accumulated blending weights.
        return value / weight
|
|
||||||
@@ -1,182 +0,0 @@
|
|||||||
import torch, os
|
|
||||||
from safetensors import safe_open
|
|
||||||
from contextlib import contextmanager
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
@contextmanager
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
    """Context manager that places newly registered parameters on ``device``.

    While active, ``torch.nn.Module.register_parameter`` is replaced by a
    wrapper that moves every registered parameter to ``device``. With
    ``include_buffers=True`` the same is done for ``register_buffer``, and
    ``torch.empty``/``zeros``/``ones``/``full`` are wrapped so that they
    always receive ``device=device``. All patches are undone on exit, also
    when the body raises.
    """
    original_register_parameter = torch.nn.Module.register_parameter
    if include_buffers:
        original_register_buffer = torch.nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        original_register_parameter(module, name, param)
        if param is None:
            return
        registered = module._parameters[name]
        # Rebuild the parameter with its own class and attributes on the target device.
        extra = registered.__dict__
        extra["requires_grad"] = param.requires_grad
        module._parameters[name] = type(registered)(registered.to(device), **extra)

    def register_empty_buffer(module, name, buffer, persistent=True):
        original_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is None:
            return
        module._buffers[name] = module._buffers[name].to(device)

    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    # Constructors are only redirected when buffers are included.
    constructor_names = ["empty", "zeros", "ones", "full"] if include_buffers else []
    tensor_constructors_to_patch = {
        constructor_name: getattr(torch, constructor_name)
        for constructor_name in constructor_names
    }

    try:
        torch.nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = register_empty_buffer
        for constructor_name in tensor_constructors_to_patch:
            setattr(torch, constructor_name, patch_tensor_constructor(getattr(torch, constructor_name)))
        yield
    finally:
        torch.nn.Module.register_parameter = original_register_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = original_register_buffer
        for constructor_name, original_constructor in tensor_constructors_to_patch.items():
            setattr(torch, constructor_name, original_constructor)
|
|
||||||
|
|
||||||
def load_state_dict_from_folder(file_path, torch_dtype=None):
    """Load every weight file directly inside ``file_path`` into one dict.

    Files are selected by their last extension (safetensors, bin, ckpt, pth,
    pt) and merged with ``dict.update`` in the order ``os.listdir`` returns
    them, so a later file overrides equal keys of an earlier one.
    """
    weight_extensions = [
        "safetensors", "bin", "ckpt", "pth", "pt"
    ]
    merged = {}
    for entry in os.listdir(file_path):
        if "." not in entry:
            continue
        if entry.split(".")[-1] not in weight_extensions:
            continue
        full_path = os.path.join(file_path, entry)
        merged.update(load_state_dict(full_path, torch_dtype=torch_dtype))
    return merged
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(file_path, torch_dtype=None):
    """Load a state dict, choosing the loader from the file name.

    Paths ending in ``.safetensors`` go to the safetensors loader; every
    other path goes to the ``torch.load`` based loader.
    """
    loader = (
        load_state_dict_from_safetensors
        if file_path.endswith(".safetensors")
        else load_state_dict_from_bin
    )
    return loader(file_path, torch_dtype=torch_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
    """Read all tensors of a safetensors file on the CPU.

    Each tensor is cast to ``torch_dtype`` when one is given.
    """
    tensors = {}
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for k in f.keys():
            tensor = f.get_tensor(k)
            tensors[k] = tensor if torch_dtype is None else tensor.to(torch_dtype)
    return tensors
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
    """Load a checkpoint with ``torch.load`` on the CPU (``weights_only=True``).

    When ``torch_dtype`` is given, every top-level tensor value is cast to
    it; values of other types are left untouched.
    """
    state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
    if torch_dtype is None:
        return state_dict
    for key in state_dict:
        entry = state_dict[key]
        if isinstance(entry, torch.Tensor):
            state_dict[key] = entry.to(torch_dtype)
    return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def search_for_embeddings(state_dict):
    """Collect every tensor stored in a (possibly nested) dict.

    Tensors are returned in iteration order; nested dicts are searched
    recursively and values of any other type are ignored.
    """
    found = []
    for value in state_dict.values():
        if isinstance(value, torch.Tensor):
            found.append(value)
        elif isinstance(value, dict):
            found.extend(search_for_embeddings(value))
    return found
|
|
||||||
|
|
||||||
|
|
||||||
def search_parameter(param, state_dict):
    """Return the name of the first tensor in ``state_dict`` that matches ``param``.

    A candidate matches when it has the same number of elements and its
    distance (``torch.dist``) to ``param`` is below 1e-3; tensors of
    different shape are compared after flattening. Returns None when
    nothing matches.
    """
    for name, candidate in state_dict.items():
        if param.numel() != candidate.numel():
            continue
        if param.shape == candidate.shape:
            distance = torch.dist(param, candidate)
        else:
            distance = torch.dist(param.flatten(), candidate.flatten())
        if distance < 1e-3:
            return name
    return None
|
|
||||||
|
|
||||||
|
|
||||||
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
    """Print a source-to-target parameter name mapping found by value matching.

    Every source tensor is looked up in ``target_state_dict`` with
    ``search_parameter`` and each hit is printed as a ``"source": "target",``
    line. With ``split_qkv=True``, an unmatched tensor whose first dimension
    is divisible by three is cut into three equal parts along that
    dimension; if all three parts are found, the list of their names is
    printed. Finally every target name that was never matched is reported.
    Nothing is returned.
    """
    matched_keys = set()
    with torch.no_grad():
        for name in source_state_dict:
            tensor = source_state_dict[name]
            rename = search_parameter(tensor, target_state_dict)
            if rename is not None:
                print(f'"{name}": "{rename}",')
                matched_keys.add(rename)
                continue
            can_split = split_qkv and len(tensor.shape)>=1 and tensor.shape[0]%3==0
            if not can_split:
                continue
            length = tensor.shape[0] // 3
            rename = [
                search_parameter(tensor[i*length: i*length+length], target_state_dict)
                for i in range(3)
            ]
            if None not in rename:
                print(f'"{name}": {rename},')
                matched_keys.update(rename)
    for name in target_state_dict:
        if name not in matched_keys:
            print("Cannot find", name, target_state_dict[name].shape)
|
|
||||||
|
|
||||||
|
|
||||||
def search_for_files(folder, extensions):
    """Recursively list files under ``folder`` whose names end with an extension.

    ``folder`` may also be a single file path, in which case it is returned
    (in a one-element list) when it ends with one of ``extensions``. Directory
    entries are visited in sorted order. Paths that are neither a directory
    nor a file yield an empty list.
    """
    collected = []
    if os.path.isdir(folder):
        for entry in sorted(os.listdir(folder)):
            collected.extend(search_for_files(os.path.join(folder, entry), extensions))
    elif os.path.isfile(folder) and any(folder.endswith(ext) for ext in extensions):
        collected.append(folder)
    return collected
|
|
||||||
|
|
||||||
|
|
||||||
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    """Serialize the keys of a state dict into one sorted, comma-joined string.

    Only string keys are considered. A tensor value contributes its key and,
    when ``with_shape`` is true, an additional ``"key:d0_d1_..."`` entry. A
    ``dict`` value contributes ``"key|<serialized nested dict>"``. Values of
    other types contribute nothing.
    """
    entries = []
    for key, value in state_dict.items():
        if not isinstance(key, str):
            continue
        if isinstance(value, torch.Tensor):
            if with_shape:
                dims = "_".join(str(dim) for dim in value.shape)
                entries.append(key + ":" + dims)
            entries.append(key)
        elif isinstance(value, dict):
            nested = convert_state_dict_keys_to_single_str(value, with_shape=with_shape)
            entries.append(key + "|" + nested)
    return ",".join(sorted(entries))
|
|
||||||
|
|
||||||
|
|
||||||
def split_state_dict_with_prefix(state_dict):
    """Split a state dict into sub-dicts grouped by the first dotted component.

    String keys are processed in sorted order; the prefix of a key is the text
    before its first ``"."`` (or the whole key when it has none). Non-string
    keys are dropped.

    Returns:
        A list of dicts, one per prefix, in order of first appearance.
    """
    groups = {}
    for key in sorted(k for k in state_dict if isinstance(k, str)):
        prefix = key.split(".")[0]
        groups.setdefault(prefix, {})[key] = state_dict[key]
    return list(groups.values())
|
|
||||||
|
|
||||||
|
|
||||||
def hash_state_dict_keys(state_dict, with_shape=True):
    """Return the MD5 hex digest of the serialized key string of ``state_dict``.

    The key string comes from ``convert_state_dict_keys_to_single_str`` and is
    encoded as UTF-8 before hashing.
    """
    signature = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    return hashlib.md5(signature.encode(encoding="UTF-8")).hexdigest()
|
|
||||||
@@ -1,13 +1,6 @@
|
|||||||
from .sd_image import SDImagePipeline
|
from .stable_diffusion import SDImagePipeline
|
||||||
from .sd_video import SDVideoPipeline
|
from .stable_diffusion_xl import SDXLImagePipeline
|
||||||
from .sdxl_image import SDXLImagePipeline
|
from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner
|
||||||
from .sdxl_video import SDXLVideoPipeline
|
from .stable_diffusion_xl_video import SDXLVideoPipeline
|
||||||
from .sd3_image import SD3ImagePipeline
|
from .stable_video_diffusion import SVDVideoPipeline
|
||||||
from .hunyuan_image import HunyuanDiTImagePipeline
|
from .hunyuan_dit import HunyuanDiTImagePipeline
|
||||||
from .svd_video import SVDVideoPipeline
|
|
||||||
from .flux_image import FluxImagePipeline
|
|
||||||
from .cog_video import CogVideoPipeline
|
|
||||||
from .omnigen_image import OmnigenImagePipeline
|
|
||||||
from .pipeline_runner import SDVideoPipelineRunner
|
|
||||||
from .hunyuan_video import HunyuanVideoPipeline
|
|
||||||
KolorsImagePipeline = SDXLImagePipeline
|
|
||||||
|
|||||||
@@ -1,117 +0,0 @@
|
|||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision.transforms import GaussianBlur
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BasePipeline(torch.nn.Module):
    """Helpers shared by the pipelines.

    Covers size rounding, image/tensor conversion, mask-weighted merging of
    predictions, prompt extension, optional CPU offloading and seeded noise
    generation.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
        super().__init__()
        self.device = device
        self.torch_dtype = torch_dtype
        # Requested sizes are rounded up to multiples of these factors.
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.cpu_offload = False
        # Attribute names of the models that load_models_to_device() manages.
        self.model_names = []


    def check_resize_height_width(self, height, width):
        """Round ``height`` and ``width`` up to the configured division factors.

        A message is printed for each dimension that had to be changed.
        """
        if height % self.height_division_factor != 0:
            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
            print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
        if width % self.width_division_factor != 0:
            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
            print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
        return height, width


    def preprocess_image(self, image):
        """Convert an image to a float tensor scaled from [0, 255] to [-1, 1].

        The array is permuted with ``(2, 0, 1)`` and given a leading batch
        axis, so the input is assumed to be channel-last (H, W, C).
        """
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image


    def preprocess_images(self, images):
        """Apply ``preprocess_image`` to every image and return the list."""
        return [self.preprocess_image(image) for image in images]


    def vae_output_to_image(self, vae_output):
        """Convert the first item of ``vae_output`` from [-1, 1] to a PIL image."""
        image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
        return image


    def vae_output_to_video(self, vae_output):
        """Convert ``vae_output`` from [-1, 1] to a list of PIL images.

        NOTE(review): the three-axis permute is kept exactly as it was; confirm
        the tensor layout expected here against the callers.
        """
        # Cast to float32 first, as vae_output_to_image does, so that
        # half-precision outputs can be converted to NumPy.
        video = vae_output.cpu().float().permute(1, 2, 0).numpy()
        video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
        return video


    def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
        """Blend ``latents`` into ``value`` using blurred, scaled masks.

        Each mask is resized to the spatial size of ``value``, binarized,
        Gaussian-blurred and multiplied by its scale; the result is a weighted
        average in which ``value`` itself has weight one. ``value`` is
        modified in place and also returned.
        """
        if len(latents) > 0:
            blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
            height, width = value.shape[-2:]
            weight = torch.ones_like(value)
            for latent, mask, scale in zip(latents, masks, scales):
                mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
                mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
                mask = blur(mask)
                value += latent * mask * scale
                weight += mask * scale
            value /= weight
        return value


    def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
        """Run ``inference_callback`` for the global and local prompts and merge.

        ``special_kwargs`` (global) and ``special_local_kwargs_list`` (one
        entry per local prompt) are forwarded as a second positional argument
        when given. The predictions are combined with ``merge_latents``.
        """
        if special_kwargs is None:
            noise_pred_global = inference_callback(prompt_emb_global)
        else:
            noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
        if special_local_kwargs_list is None:
            noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
        else:
            noise_pred_locals = [
                inference_callback(prompt_emb_local, local_kwargs)
                for prompt_emb_local, local_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)
            ]
        noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
        return noise_pred


    def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
        """Extend the prompt through ``self.prompter`` and append its additions.

        The prompter may replace the prompt and contribute extra local prompts
        and masks; every contributed mask gets a scale of 100.0.

        Returns:
            ``(prompt, local_prompts, masks, mask_scales)`` where the three
            lists are new objects; the caller's lists are left untouched.
        """
        # Copy before extending: `+=` on the caller's own lists would make
        # them grow every time the same lists are passed in again.
        local_prompts = list(local_prompts or [])
        masks = list(masks or [])
        mask_scales = list(mask_scales or [])
        extended_prompt_dict = self.prompter.extend_prompt(prompt)
        prompt = extended_prompt_dict.get("prompt", prompt)
        extra_masks = extended_prompt_dict.get("masks", [])
        local_prompts += extended_prompt_dict.get("prompts", [])
        masks += extra_masks
        mask_scales += [100.0] * len(extra_masks)
        return prompt, local_prompts, masks, mask_scales


    def enable_cpu_offload(self):
        """Turn on CPU offloading; see ``load_models_to_device``."""
        self.cpu_offload = True


    def load_models_to_device(self, loadmodel_names=()):
        """Move the named models to ``self.device`` and all others to the CPU.

        Does nothing unless CPU offloading has been enabled. Only attributes
        listed in ``self.model_names`` are offloaded; ``None`` attributes are
        skipped.
        """
        # only load models to device if cpu_offload is enabled
        if not self.cpu_offload:
            return
        # offload the unneeded models to cpu
        for model_name in self.model_names:
            if model_name not in loadmodel_names:
                model = getattr(self, model_name)
                if model is not None:
                    model.cpu()
        # load the needed models to device
        for model_name in loadmodel_names:
            model = getattr(self, model_name)
            if model is not None:
                model.to(self.device)
        # free the cached CUDA memory
        torch.cuda.empty_cache()


    def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
        """Sample standard normal noise, reproducibly when ``seed`` is given."""
        generator = None if seed is None else torch.Generator(device).manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        return noise
|
|
||||||
@@ -1,135 +0,0 @@
|
|||||||
from ..models import ModelManager, FluxTextEncoder2, CogDiT, CogVAEEncoder, CogVAEDecoder
|
|
||||||
from ..prompters import CogPrompter
|
|
||||||
from ..schedulers import EnhancedDDIMScheduler
|
|
||||||
from .base import BasePipeline
|
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
|
||||||
from einops import rearrange
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CogVideoPipeline(BasePipeline):
    """Video pipeline built from a text encoder, a CogDiT and a Cog VAE.

    Sizes are rounded to multiples of 16 and a v-prediction DDIM scheduler
    with zero-terminal-SNR rescaling is used.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
        self.scheduler = EnhancedDDIMScheduler(rescale_zero_terminal_snr=True, prediction_type="v_prediction")
        self.prompter = CogPrompter()
        # models
        self.text_encoder: FluxTextEncoder2 = None
        self.dit: CogDiT = None
        self.vae_encoder: CogVAEEncoder = None
        self.vae_decoder: CogVAEDecoder = None


    def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
        """Fetch the models from ``model_manager`` and wire up the prompter."""
        self.text_encoder = model_manager.fetch_model("flux_text_encoder_2")
        self.dit = model_manager.fetch_model("cog_dit")
        self.vae_encoder = model_manager.fetch_model("cog_vae_encoder")
        self.vae_decoder = model_manager.fetch_model("cog_vae_decoder")
        self.prompter.fetch_models(self.text_encoder)
        self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)


    @staticmethod
    def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
        """Build a pipeline using the manager's device, dtype and models."""
        pipe = CogVideoPipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype
        )
        pipe.fetch_models(model_manager, prompt_refiner_classes)
        return pipe


    def tensor2video(self, frames):
        """Convert a ``C T H W`` tensor in [-1, 1] to a list of PIL frames."""
        frames = rearrange(frames, "C T H W -> T H W C")
        frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
        frames = [Image.fromarray(frame) for frame in frames]
        return frames


    def encode_prompt(self, prompt, positive=True):
        """Encode ``prompt`` with the prompter and wrap it as DiT kwargs."""
        prompt_emb = self.prompter.encode_prompt(prompt, device=self.device, positive=positive)
        return {"prompt_emb": prompt_emb}


    def prepare_extra_input(self, latents):
        """Build the rotary positional embedding kwargs for ``latents``."""
        return {"image_rotary_emb": self.dit.prepare_rotary_positional_embeddings(latents.shape[3], latents.shape[4], latents.shape[2], device=self.device)}


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        input_video=None,
        cfg_scale=7.0,
        denoising_strength=1.0,
        num_frames=49,
        height=480,
        width=720,
        num_inference_steps=20,
        tiled=False,
        tile_size=(60, 90),
        tile_stride=(30, 45),
        seed=None,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate a video from ``prompt`` and return it as PIL frames.

        With ``denoising_strength == 1.0`` sampling starts from pure noise.
        Otherwise ``input_video`` is encoded by the VAE and noised to the
        first scheduler timestep. Classifier-free guidance is applied unless
        ``cfg_scale == 1.0``.

        Raises:
            ValueError: if ``denoising_strength`` is not 1.0 and no
                ``input_video`` is given.
        """
        height, width = self.check_resize_height_width(height, width)

        # Tiler parameters
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)

        # Prepare latent tensors
        noise = self.generate_noise((1, 16, num_frames // 4 + 1, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)

        if denoising_strength == 1.0:
            latents = noise.clone()
        else:
            # Fail with a clear message instead of an opaque TypeError from
            # iterating over None inside preprocess_images().
            if input_video is None:
                raise ValueError("input_video is required when denoising_strength is not 1.0.")
            input_video = self.preprocess_images(input_video)
            input_video = torch.stack(input_video, dim=2)
            latents = self.vae_encoder.encode_video(input_video, **tiler_kwargs, progress_bar=progress_bar_cmd).to(dtype=self.torch_dtype)
            latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
        if not tiled: latents = latents.to(self.device)

        # Encode prompt
        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
        if cfg_scale != 1.0:
            prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)

        # Extra input
        extra_input = self.prepare_extra_input(latents)

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(self.device)

            # Classifier-free guidance
            noise_pred_posi = self.dit(
                latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
            )
            if cfg_scale != 1.0:
                noise_pred_nega = self.dit(
                    latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
                )
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi

            # DDIM
            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

            # Update progress bar
            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        video = self.vae_decoder.decode_video(latents.to("cpu"), **tiler_kwargs, progress_bar=progress_bar_cmd)
        video = self.tensor2video(video[0])

        return video
|
|
||||||
@@ -22,10 +22,6 @@ def lets_dance(
|
|||||||
device = "cuda",
|
device = "cuda",
|
||||||
vram_limit_level = 0,
|
vram_limit_level = 0,
|
||||||
):
|
):
|
||||||
# 0. Text embedding alignment (only for video processing)
|
|
||||||
if encoder_hidden_states.shape[0] != sample.shape[0]:
|
|
||||||
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
|
|
||||||
|
|
||||||
# 1. ControlNet
|
# 1. ControlNet
|
||||||
# This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
|
# This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
|
||||||
# I leave it here because I intend to do something interesting on the ControlNets.
|
# I leave it here because I intend to do something interesting on the ControlNets.
|
||||||
@@ -54,7 +50,7 @@ def lets_dance(
|
|||||||
additional_res_stack = None
|
additional_res_stack = None
|
||||||
|
|
||||||
# 2. time
|
# 2. time
|
||||||
time_emb = unet.time_proj(timestep).to(sample.dtype)
|
time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
|
||||||
time_emb = unet.time_embedding(time_emb)
|
time_emb = unet.time_embedding(time_emb)
|
||||||
|
|
||||||
# 3. pre-process
|
# 3. pre-process
|
||||||
@@ -136,42 +132,8 @@ def lets_dance_xl(
|
|||||||
device = "cuda",
|
device = "cuda",
|
||||||
vram_limit_level = 0,
|
vram_limit_level = 0,
|
||||||
):
|
):
|
||||||
# 0. Text embedding alignment (only for video processing)
|
|
||||||
if encoder_hidden_states.shape[0] != sample.shape[0]:
|
|
||||||
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
|
|
||||||
if add_text_embeds.shape[0] != sample.shape[0]:
|
|
||||||
add_text_embeds = add_text_embeds.repeat(sample.shape[0], 1)
|
|
||||||
|
|
||||||
# 1. ControlNet
|
|
||||||
controlnet_insert_block_id = 22
|
|
||||||
if controlnet is not None and controlnet_frames is not None:
|
|
||||||
res_stacks = []
|
|
||||||
# process controlnet frames with batch
|
|
||||||
for batch_id in range(0, sample.shape[0], controlnet_batch_size):
|
|
||||||
batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
|
|
||||||
res_stack = controlnet(
|
|
||||||
sample[batch_id: batch_id_],
|
|
||||||
timestep,
|
|
||||||
encoder_hidden_states[batch_id: batch_id_],
|
|
||||||
controlnet_frames[:, batch_id: batch_id_],
|
|
||||||
add_time_id=add_time_id,
|
|
||||||
add_text_embeds=add_text_embeds,
|
|
||||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
|
||||||
unet=unet, # for Kolors, some modules in ControlNets will be replaced.
|
|
||||||
)
|
|
||||||
if vram_limit_level >= 1:
|
|
||||||
res_stack = [res.cpu() for res in res_stack]
|
|
||||||
res_stacks.append(res_stack)
|
|
||||||
# concat the residual
|
|
||||||
additional_res_stack = []
|
|
||||||
for i in range(len(res_stacks[0])):
|
|
||||||
res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
|
|
||||||
additional_res_stack.append(res)
|
|
||||||
else:
|
|
||||||
additional_res_stack = None
|
|
||||||
|
|
||||||
# 2. time
|
# 2. time
|
||||||
t_emb = unet.time_proj(timestep).to(sample.dtype)
|
t_emb = unet.time_proj(timestep[None]).to(sample.dtype)
|
||||||
t_emb = unet.time_embedding(t_emb)
|
t_emb = unet.time_embedding(t_emb)
|
||||||
|
|
||||||
time_embeds = unet.add_time_proj(add_time_id)
|
time_embeds = unet.add_time_proj(add_time_id)
|
||||||
@@ -185,36 +147,16 @@ def lets_dance_xl(
|
|||||||
# 3. pre-process
|
# 3. pre-process
|
||||||
height, width = sample.shape[2], sample.shape[3]
|
height, width = sample.shape[2], sample.shape[3]
|
||||||
hidden_states = unet.conv_in(sample)
|
hidden_states = unet.conv_in(sample)
|
||||||
text_emb = encoder_hidden_states if unet.text_intermediate_proj is None else unet.text_intermediate_proj(encoder_hidden_states)
|
text_emb = encoder_hidden_states
|
||||||
res_stack = [hidden_states]
|
res_stack = [hidden_states]
|
||||||
|
|
||||||
# 4. blocks
|
# 4. blocks
|
||||||
for block_id, block in enumerate(unet.blocks):
|
for block_id, block in enumerate(unet.blocks):
|
||||||
# 4.1 UNet
|
hidden_states, time_emb, text_emb, res_stack = block(
|
||||||
if isinstance(block, PushBlock):
|
hidden_states, time_emb, text_emb, res_stack,
|
||||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||||
if vram_limit_level>=1:
|
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {})
|
||||||
res_stack[-1] = res_stack[-1].cpu()
|
)
|
||||||
elif isinstance(block, PopBlock):
|
|
||||||
if vram_limit_level>=1:
|
|
||||||
res_stack[-1] = res_stack[-1].to(device)
|
|
||||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
|
||||||
else:
|
|
||||||
hidden_states_input = hidden_states
|
|
||||||
hidden_states_output = []
|
|
||||||
for batch_id in range(0, sample.shape[0], unet_batch_size):
|
|
||||||
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
|
|
||||||
hidden_states, _, _, _ = block(
|
|
||||||
hidden_states_input[batch_id: batch_id_],
|
|
||||||
time_emb[batch_id: batch_id_],
|
|
||||||
text_emb[batch_id: batch_id_],
|
|
||||||
res_stack,
|
|
||||||
cross_frame_attention=cross_frame_attention,
|
|
||||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
|
|
||||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
|
||||||
)
|
|
||||||
hidden_states_output.append(hidden_states)
|
|
||||||
hidden_states = torch.concat(hidden_states_output, dim=0)
|
|
||||||
# 4.2 AnimateDiff
|
# 4.2 AnimateDiff
|
||||||
if motion_modules is not None:
|
if motion_modules is not None:
|
||||||
if block_id in motion_modules.call_block_id:
|
if block_id in motion_modules.call_block_id:
|
||||||
@@ -223,10 +165,6 @@ def lets_dance_xl(
|
|||||||
hidden_states, time_emb, text_emb, res_stack,
|
hidden_states, time_emb, text_emb, res_stack,
|
||||||
batch_size=1
|
batch_size=1
|
||||||
)
|
)
|
||||||
# 4.3 ControlNet
|
|
||||||
if block_id == controlnet_insert_block_id and additional_res_stack is not None:
|
|
||||||
hidden_states += additional_res_stack.pop().to(device)
|
|
||||||
res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
|
|
||||||
|
|
||||||
# 5. output
|
# 5. output
|
||||||
hidden_states = unet.conv_norm_out(hidden_states)
|
hidden_states = unet.conv_norm_out(hidden_states)
|
||||||
|
|||||||
@@ -1,544 +0,0 @@
|
|||||||
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
|
|
||||||
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
|
||||||
from ..prompters import FluxPrompter
|
|
||||||
from ..schedulers import FlowMatchScheduler
|
|
||||||
from .base import BasePipeline
|
|
||||||
from typing import List
|
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from ..models.tiler import FastTileWorker
|
|
||||||
from transformers import SiglipVisionModel
|
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
|
|
||||||
class FluxImagePipeline(BasePipeline):
|
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
    """Set up the scheduler, the prompter and empty slots for every model.

    Height and width division factors are fixed to 16. All model attributes
    start as ``None`` and are listed in ``self.model_names``.
    """
    super().__init__(
        device=device,
        torch_dtype=torch_dtype,
        height_division_factor=16,
        width_division_factor=16,
    )
    self.scheduler = FlowMatchScheduler()
    self.prompter = FluxPrompter()

    # Text encoders
    self.text_encoder_1: SD3TextEncoder1 = None
    self.text_encoder_2: FluxTextEncoder2 = None
    # Denoising model and VAE
    self.dit: FluxDiT = None
    self.vae_decoder: FluxVAEDecoder = None
    self.vae_encoder: FluxVAEEncoder = None
    # Optional conditioning models
    self.controlnet: FluxMultiControlNetManager = None
    self.ipadapter: FluxIpAdapter = None
    self.ipadapter_image_encoder: SiglipVisionModel = None

    # Names of the attributes above, in the same order.
    self.model_names = [
        'text_encoder_1',
        'text_encoder_2',
        'dit',
        'vae_decoder',
        'vae_encoder',
        'controlnet',
        'ipadapter',
        'ipadapter_image_encoder',
    ]
|
|
||||||
|
|
||||||
|
|
||||||
def denoising_model(self):
    """Return the model stored in ``self.dit``."""
    dit = self.dit
    return dit
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[]):
    """Fetch all models from ``model_manager`` and attach them to the pipeline.

    Loads the two text encoders, the DiT and the VAE halves, hands the text
    encoders plus refiner/extender classes to the prompter, builds one
    ``ControlNetUnit`` per entry of ``controlnet_config_units``, and finally
    fetches the IP-Adapter and its image encoder.
    """
    fetch = model_manager.fetch_model

    # Core models
    self.text_encoder_1 = fetch("sd3_text_encoder_1")
    self.text_encoder_2 = fetch("flux_text_encoder_2")
    self.dit = fetch("flux_dit")
    self.vae_decoder = fetch("flux_vae_decoder")
    self.vae_encoder = fetch("flux_vae_encoder")

    # Prompter
    self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
    self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
    self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes)

    # ControlNets
    controlnet_units = [
        ControlNetUnit(
            Annotator(config.processor_id, device=self.device, skip_processor=config.skip_processor),
            fetch("flux_controlnet", config.model_path),
            config.scale
        )
        for config in controlnet_config_units
    ]
    self.controlnet = FluxMultiControlNetManager(controlnet_units)

    # IP-Adapters
    self.ipadapter = fetch("flux_ipadapter")
    self.ipadapter_image_encoder = fetch("siglip_vision_model")
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
    """Create a pipeline and populate it from ``model_manager``.

    The manager's device is used unless ``device`` is given; the dtype always
    comes from the manager.
    """
    target_device = device if device is not None else model_manager.device
    pipe = FluxImagePipeline(device=target_device, torch_dtype=model_manager.torch_dtype)
    pipe.fetch_models(
        model_manager,
        controlnet_config_units,
        prompt_refiner_classes,
        prompt_extender_classes,
    )
    return pipe
|
|
||||||
|
|
||||||
|
|
||||||
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
    """Run ``self.vae_encoder`` on ``image`` with the given tiling options."""
    tiling = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
    return self.vae_encoder(image, **tiling)
|
|
||||||
|
|
||||||
|
|
||||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
    """Decode ``latent`` with ``self.vae_decoder`` and convert it to an image.

    The latent is moved to ``self.device`` first; the decoder output goes
    through ``self.vae_output_to_image``.
    """
    latent_on_device = latent.to(self.device)
    decoded = self.vae_decoder(
        latent_on_device,
        tiled=tiled,
        tile_size=tile_size,
        tile_stride=tile_stride,
    )
    return self.vae_output_to_image(decoded)
|
|
||||||
|
|
||||||
|
|
||||||
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
    """Encode ``prompt`` with the prompter and return the result as a dict.

    Returns:
        A dict with the keys ``prompt_emb``, ``pooled_prompt_emb`` and
        ``text_ids``, in the order the prompter returned them.
    """
    encoded = self.prompter.encode_prompt(
        prompt,
        device=self.device,
        positive=positive,
        t5_sequence_length=t5_sequence_length,
    )
    prompt_emb, pooled_prompt_emb, text_ids = encoded
    return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_extra_input(self, latents=None, guidance=1.0):
    """Build the ``image_ids`` and ``guidance`` inputs for ``latents``.

    ``guidance`` is repeated once per batch item (``latents.shape[0]``) and
    placed on the device and dtype of ``latents``.
    """
    batch_size = latents.shape[0]
    image_ids = self.dit.prepare_image_ids(latents)
    guidance_values = torch.Tensor([guidance] * batch_size)
    guidance_values = guidance_values.to(device=latents.device, dtype=latents.dtype)
    return {"image_ids": image_ids, "guidance": guidance_values}
|
|
||||||
|
|
||||||
|
|
||||||
def apply_controlnet_mask_on_latents(self, latents, mask):
    """Append the inverted, resized mask to ``latents`` as an extra channel.

    The mask is preprocessed, rescaled from [-1, 1] to [0, 1], averaged over
    dim 1, interpolated to the spatial size of ``latents`` and inverted
    before being concatenated along dim 1.
    """
    mask_tensor = self.preprocess_image(mask)
    mask_tensor = ((mask_tensor + 1) / 2).mean(dim=1, keepdim=True)
    mask_tensor = mask_tensor.to(dtype=self.torch_dtype, device=self.device)
    resized = torch.nn.functional.interpolate(mask_tensor, size=latents.shape[-2:])
    inverted = 1 - resized
    return torch.concat([latents, inverted], dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_controlnet_mask_on_image(self, image, mask):
    """Zero out the pixels of ``image`` where the mask is positive.

    The mask is resized to the image size, preprocessed and averaged over its
    first two dims; pixels where that average is above zero are set to 0.
    A new PIL image is returned.
    """
    resized_mask = mask.resize(image.size)
    mask_values = self.preprocess_image(resized_mask).mean(dim=[0, 1])
    pixels = np.array(image)
    pixels[mask_values > 0] = 0
    return Image.fromarray(pixels)
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_controlnet_input(self, controlnet_image, controlnet_inpaint_mask, tiler_kwargs):
    """Annotate and VAE-encode the control image for each ControlNet processor.

    A single PIL image is reused for all processors. For processors whose
    ``processor_id`` is ``"inpaint"`` and when an inpaint mask is given, the
    mask is applied to the annotated image before encoding and appended to
    the latents after encoding.

    Returns:
        A list with one encoded tensor per processor.
    """
    processors = self.controlnet.processors
    processor_count = len(processors)
    if isinstance(controlnet_image, Image.Image):
        controlnet_image = [controlnet_image] * processor_count

    frames = []
    for index in range(processor_count):
        # image annotator
        annotated = self.controlnet.process_image(controlnet_image[index], processor_id=index)[0]
        use_inpaint_mask = controlnet_inpaint_mask is not None and processors[index].processor_id == "inpaint"
        if use_inpaint_mask:
            annotated = self.apply_controlnet_mask_on_image(annotated, controlnet_inpaint_mask)

        # image to tensor
        tensor = self.preprocess_image(annotated).to(device=self.device, dtype=self.torch_dtype)

        # vae encoder
        encoded = self.encode_image(tensor, **tiler_kwargs)
        if use_inpaint_mask:
            encoded = self.apply_controlnet_mask_on_latents(encoded, controlnet_inpaint_mask)

        # store it
        frames.append(encoded)
    return frames
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_ipadapter_inputs(self, images, height=384, width=384):
    """Resize the images and stack them into one batch tensor.

    Each image is converted to RGB and resized to ``(width, height)`` with
    ``resample=3``, then preprocessed, moved to the pipeline device/dtype and
    concatenated along dim 0.
    """
    resized_images = []
    for image in images:
        resized_images.append(image.convert("RGB").resize((width, height), resample=3))
    tensors = []
    for image in resized_images:
        tensors.append(self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype))
    return torch.cat(tensors, dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
def inpaint_fusion(self, latents, inpaint_latents, pred_noise, fg_mask, bg_mask, progress_id, background_weight=0.):
    """Combine the predicted noise with the noise implied by the inpaint latents.

    The implied noise is ``(latents - inpaint_latents) / sigma`` for the
    current step. Where ``fg_mask`` is set it is replaced by ``pred_noise``;
    where ``bg_mask`` is set, ``pred_noise`` is mixed in with
    ``background_weight`` and the sum is normalized by ``1 + background_weight``.
    """
    sigma = self.scheduler.sigmas[progress_id]
    # inpaint noise
    fused = (latents - inpaint_latents) / sigma
    # merge noise
    normalizer = torch.ones_like(fused)
    fused[fg_mask] = pred_noise[fg_mask]
    fused[bg_mask] += pred_noise[bg_mask] * background_weight
    normalizer[bg_mask] += background_weight
    fused /= normalizer
    return fused
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_masks(self, masks, height, width, dim):
    """Turn PIL masks into binary tensors with ``dim`` channels.

    Each mask is resized to ``(width, height)`` with nearest-neighbour
    resampling, preprocessed, averaged over dim 1 and thresholded at zero;
    the result is repeated ``dim`` times along dim 1 and moved to the
    pipeline device/dtype.
    """
    def to_tensor(mask):
        resized = mask.resize((width, height), resample=Image.NEAREST)
        binary = self.preprocess_image(resized).mean(dim=1, keepdim=True) > 0
        return binary.repeat(1, dim, 1, 1).to(device=self.device, dtype=self.torch_dtype)

    return [to_tensor(mask) for mask in masks]
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, enable_eligen_inpaint=False):
    """Encode entity prompts and masks, optionally with inpaint masks.

    Masks are processed at one eighth of ``height`` and ``width``. With
    ``enable_eligen_inpaint`` the union of all entity masks (repeated over 16
    channels) is returned as ``fg_mask`` and its complement as ``bg_mask``;
    otherwise both are ``None``.

    Returns:
        ``(entity_prompts, entity_masks, fg_mask, bg_mask)`` where the first
        two have a leading axis added with ``unsqueeze(0)``.
    """
    latent_height, latent_width = height//8, width//8
    fg_mask = bg_mask = None
    if enable_eligen_inpaint:
        # Work on copies so the caller's mask objects are not touched here.
        mask_copies = deepcopy(entity_masks)
        per_entity = [
            self.preprocess_image(mask.resize((latent_width, latent_height))).mean(dim=1, keepdim=True)
            for mask in mask_copies
        ]
        fg_masks = (torch.cat(per_entity) > 0).float()
        fg_mask = fg_masks.sum(dim=0, keepdim=True).repeat(1, 16, 1, 1) > 0
        bg_mask = ~fg_mask
    mask_tensors = self.preprocess_masks(entity_masks, latent_height, latent_width, 1)
    entity_masks = torch.cat(mask_tensors, dim=0).unsqueeze(0) # b, n_mask, c, h, w
    encoded = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)
    entity_prompts = encoded['prompt_emb'].unsqueeze(0)
    return entity_prompts, entity_masks, fg_mask, bg_mask
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_latents(self, input_image, height, width, seed, tiled, tile_size, tile_stride):
    """Create the starting latents, from noise or from a noised input image.

    Without ``input_image`` the latents are pure noise of shape
    ``(1, 16, height//8, width//8)``. With one, the image is VAE-encoded and
    noise is added for the first scheduler timestep.

    Returns:
        ``(latents, input_latents)``; ``input_latents`` is ``None`` when no
        input image was given.
    """
    noise_shape = (1, 16, height//8, width//8)
    if input_image is None:
        latents = self.generate_noise(noise_shape, seed=seed, device=self.device, dtype=self.torch_dtype)
        return latents, None

    self.load_models_to_device(['vae_encoder'])
    image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
    input_latents = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
    # Noise is drawn after encoding, in the same order as before.
    noise = self.generate_noise(noise_shape, seed=seed, device=self.device, dtype=self.torch_dtype)
    latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0])
    return latents, input_latents
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_ipadapter(self, ipadapter_images, ipadapter_scale):
    """Build the positive and negative IP-Adapter kwargs.

    Without images both results carry an empty ``ipadapter_kwargs_list``.
    Otherwise the images are encoded with the image encoder
    (``pooler_output``); the positive kwargs come from that encoding with
    ``ipadapter_scale`` and the negative kwargs from an all-zero encoding.
    """
    if ipadapter_images is None:
        return {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}

    self.load_models_to_device(['ipadapter_image_encoder'])
    image_batch = self.prepare_ipadapter_inputs(ipadapter_images)
    encoding = self.ipadapter_image_encoder(image_batch).pooler_output
    self.load_models_to_device(['ipadapter'])
    positive = self.ipadapter(encoding, scale=ipadapter_scale)
    negative = self.ipadapter(torch.zeros_like(encoding))
    return {"ipadapter_kwargs_list": positive}, {"ipadapter_kwargs_list": negative}
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_controlnet(self, controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative):
    """Assemble ControlNet keyword arguments for the global and local prompts.

    Returns:
        ``(controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs)``.
        The negative kwargs are the very same dict as the positive ones when
        ``enable_controlnet_on_negative`` is truthy, otherwise an empty dict.
        ``local_controlnet_kwargs`` is a per-mask list only when both masks
        and an inpaint mask are supplied together with a ControlNet image;
        with an image but without that combination it is ``None``.
    """
    if controlnet_image is None:
        controlnet_kwargs_posi = {"controlnet_frames": None}
        local_controlnet_kwargs = [{}] * len(masks)
    else:
        self.load_models_to_device(['vae_encoder'])
        global_frames = self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)
        controlnet_kwargs_posi = {"controlnet_frames": global_frames}
        local_controlnet_kwargs = None
        if len(masks) > 0 and controlnet_inpaint_mask is not None:
            print("The controlnet_inpaint_mask will be overridden by masks.")
            local_controlnet_kwargs = []
            for mask in masks:
                # Each local prompt uses its own mask in place of the inpaint mask.
                local_frames = self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)
                local_controlnet_kwargs.append({"controlnet_frames": local_frames})

    if enable_controlnet_on_negative:
        controlnet_kwargs_nega = controlnet_kwargs_posi
    else:
        controlnet_kwargs_nega = {}
    return controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_eligen(self, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale):
    """Prepare entity-control (EliGen) keyword arguments for both branches.

    Returns:
        ``(eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask)``.
        Both kwargs dicts carry the keys ``"entity_prompt_emb"`` and
        ``"entity_masks"``.  Everything is ``None`` when no entity masks
        are given.  The negative side is populated only when
        ``enable_eligen_on_negative`` is set and ``cfg_scale != 1.0``; it
        then reuses the positive masks and a repeated copy of the negative
        prompt embedding.
    """
    entity_prompt_emb_posi = entity_masks_posi = None
    entity_prompt_emb_nega = entity_masks_nega = None
    fg_mask = bg_mask = None

    if eligen_entity_masks is not None:
        entity_prompt_emb_posi, entity_masks_posi, fg_mask, bg_mask = self.prepare_entity_inputs(
            eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint
        )
        if enable_eligen_on_negative and cfg_scale != 1.0:
            # Repeat the negative embedding along a new dim 1 to match dim 1 of the masks.
            repeat_count = entity_masks_posi.shape[1]
            entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, repeat_count, 1, 1)
            entity_masks_nega = entity_masks_posi

    eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi}
    eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega}
    return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
    """Encode the global, negative and local prompts.

    Returns:
        ``(prompt_emb_posi, prompt_emb_nega, prompt_emb_locals)``.  The
        negative embedding is ``None`` when ``cfg_scale == 1.0``, because
        the negative prompt is not encoded in that case.
    """
    self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])

    # Let the pipeline extend the prompt set before encoding.
    prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)

    prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)

    if cfg_scale != 1.0:
        prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
    else:
        prompt_emb_nega = None

    prompt_emb_locals = []
    for prompt_local in local_prompts:
        prompt_emb_locals.append(self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length))

    return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
def __call__(
    self,
    # Prompt
    prompt,
    negative_prompt="",
    cfg_scale=1.0,
    embedded_guidance=3.5,
    t5_sequence_length=512,
    # Image
    input_image=None,
    denoising_strength=1.0,
    height=1024,
    width=1024,
    seed=None,
    # Steps
    num_inference_steps=30,
    # local prompts
    local_prompts=(),
    masks=(),
    mask_scales=(),
    # ControlNet
    controlnet_image=None,
    controlnet_inpaint_mask=None,
    enable_controlnet_on_negative=False,
    # IP-Adapter
    ipadapter_images=None,
    ipadapter_scale=1.0,
    # EliGen
    eligen_entity_prompts=None,
    eligen_entity_masks=None,
    enable_eligen_on_negative=False,
    enable_eligen_inpaint=False,
    # TeaCache
    tea_cache_l1_thresh=None,
    # Tile
    tiled=False,
    tile_size=128,
    tile_stride=64,
    # Progress bar
    progress_bar_cmd=tqdm,
    progress_bar_st=None,
):
    """Run the full generation loop and return the decoded image.

    The method prepares the scheduler, latents, prompt embeddings and the
    optional EliGen / IP-Adapter / ControlNet / TeaCache inputs, iterates
    over the scheduler's timesteps calling ``lets_dance_flux``, and finally
    decodes the resulting latents.

    Args:
        prompt, negative_prompt: texts encoded by ``prepare_prompts``; the
            negative prompt is only used when ``cfg_scale != 1.0``.
        cfg_scale: classifier-free guidance weight; ``1.0`` skips the
            negative branch entirely.
        embedded_guidance: forwarded to ``prepare_extra_input`` as ``guidance``.
        t5_sequence_length: forwarded to prompt encoding.
        input_image, denoising_strength: optional starting image and the
            strength passed to ``scheduler.set_timesteps``.
        height, width: requested size, adjusted by ``check_resize_height_width``.
        seed: forwarded to noise generation.
        num_inference_steps: number of steps given to the scheduler (and to
            ``TeaCache`` when enabled).
        local_prompts, masks, mask_scales: regional prompt controls.
        controlnet_image, controlnet_inpaint_mask,
        enable_controlnet_on_negative: ControlNet inputs.
        ipadapter_images, ipadapter_scale: IP-Adapter inputs.
        eligen_entity_prompts, eligen_entity_masks,
        enable_eligen_on_negative, enable_eligen_inpaint: entity control.
        tea_cache_l1_thresh: enables ``TeaCache`` when not ``None``.
        tiled, tile_size, tile_stride: tiling options shared by the model
            call and the image decoder.
        progress_bar_cmd: wrapper applied to the timestep iterable.
        progress_bar_st: optional object whose ``progress(fraction)`` method
            is called after every step.

    Returns:
        Whatever ``self.decode_image`` returns for the final latents.
    """
    height, width = self.check_resize_height_width(height, width)

    # Tiler parameters
    tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

    # Prepare scheduler
    self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

    # Prepare latent tensors
    latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride)

    # Prompt
    prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale)

    # Extra input
    extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)

    # Entity control
    eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)

    # IP-Adapter
    ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = self.prepare_ipadapter(ipadapter_images, ipadapter_scale)

    # ControlNets
    controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)

    # TeaCache
    tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}

    # Denoise
    self.load_models_to_device(['dit', 'controlnet'])
    for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
        timestep = timestep.unsqueeze(0).to(self.device)

        # Positive side
        # The callback reads `latents` and `timestep` from the current iteration;
        # it is invoked immediately below, before either is rebound.
        def inference_callback(prompt_emb_posi, controlnet_kwargs):
            return lets_dance_flux(
                dit=self.dit, controlnet=self.controlnet,
                hidden_states=latents, timestep=timestep,
                **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
            )

        noise_pred_posi = self.control_noise_via_local_prompts(
            prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
            special_kwargs=controlnet_kwargs_posi, special_local_kwargs_list=local_controlnet_kwargs
        )

        # Inpaint
        if enable_eligen_inpaint:
            noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id)

        # Classifier-free guidance
        if cfg_scale != 1.0:
            # Negative side (TeaCache is deliberately not passed here, so only
            # the positive call advances the cache's step counter).
            noise_pred_nega = lets_dance_flux(
                dit=self.dit, controlnet=self.controlnet,
                hidden_states=latents, timestep=timestep,
                **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
            )
            noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
        else:
            noise_pred = noise_pred_posi

        # Iterate
        latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

        # UI
        if progress_bar_st is not None:
            # This runs after the step completed, so report progress_id + 1
            # finished steps; the bar then reaches 1.0 on the last step.
            progress_bar_st.progress((progress_id + 1) / len(self.scheduler.timesteps))

    # Decode image
    self.load_models_to_device(['vae_decoder'])
    image = self.decode_image(latents, **tiler_kwargs)

    # Offload all models
    self.load_models_to_device([])
    return image
|
|
||||||
|
|
||||||
|
|
||||||
class TeaCache:
    """Decide per step whether the transformer blocks can be skipped.

    ``check`` accumulates a rescaled relative L1 change of the first
    block's modulated input between consecutive steps.  While the
    accumulated value stays below ``rel_l1_thresh`` the caller may reuse
    the residual recorded by ``store`` (applied with ``update``) instead of
    running the blocks again.  The first and the last step of every run of
    ``num_inference_steps`` steps are always recomputed.
    """

    # Coefficients (highest power first) of the polynomial that rescales the
    # raw relative L1 distance.  The polynomial is constant, so it is built
    # once here rather than on every call to ``check``.
    _RESCALE_COEFFICIENTS = (4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01)
    _rescale_func = np.poly1d(_RESCALE_COEFFICIENTS)

    def __init__(self, num_inference_steps, rel_l1_thresh):
        """Create a cache for runs of ``num_inference_steps`` steps.

        Args:
            num_inference_steps: number of ``check`` calls that make up one
                run; the step counter wraps to 0 after that many calls.
            rel_l1_thresh: accumulated-distance threshold below which a
                step is served from the cache.
        """
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None

    def check(self, dit: "FluxDiT", hidden_states, conditioning):
        """Advance one step and report whether the cached residual may be used.

        Args:
            dit: model whose ``blocks[0].norm1_a`` produces the modulated input.
            hidden_states: current input to the transformer blocks.
            conditioning: embedding passed to ``norm1_a`` as ``emb``.

        Returns:
            ``True`` when the caller should skip the blocks and call
            ``update``; ``False`` when it must run the blocks and then call
            ``store``.
        """
        inp = hidden_states.clone()
        temb_ = conditioning.clone()
        modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_)

        if self.step == 0 or self.step == self.num_inference_steps - 1:
            # First and last step are always computed in full.
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            previous = self.previous_modulated_input
            relative_change = ((modulated_inp - previous).abs().mean() / previous.abs().mean()).cpu().item()
            self.accumulated_rel_l1_distance += self._rescale_func(relative_change)
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0

        self.previous_modulated_input = modulated_inp
        self.step += 1
        if self.step == self.num_inference_steps:
            # Wrap around so the same instance can serve another run.
            self.step = 0

        if should_calc:
            # Remember the block input so ``store`` can derive the residual.
            self.previous_hidden_states = hidden_states.clone()
        return not should_calc

    def store(self, hidden_states):
        """Record the residual between the block output and the saved input."""
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Return ``hidden_states`` plus the most recently stored residual."""
        hidden_states = hidden_states + self.previous_residual
        return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
def lets_dance_flux(
    dit: FluxDiT,
    controlnet: FluxMultiControlNetManager = None,
    hidden_states=None,
    timestep=None,
    prompt_emb=None,
    pooled_prompt_emb=None,
    guidance=None,
    text_ids=None,
    image_ids=None,
    controlnet_frames=None,
    tiled=False,
    tile_size=128,
    tile_stride=64,
    entity_prompt_emb=None,
    entity_masks=None,
    ipadapter_kwargs_list={},
    tea_cache: TeaCache = None,
    **kwargs
):
    """Run one forward pass of the DiT, with optional ControlNet, entity
    masks, IP-Adapter inputs, TeaCache and tiling.

    When ``tiled`` is true the function calls itself once per tile through
    ``FastTileWorker`` with ``tiled=False`` and ``image_ids=None``.  The
    per-tile call forwards only the arguments listed there plus
    ``**kwargs``; ``entity_prompt_emb``, ``entity_masks``,
    ``ipadapter_kwargs_list`` and ``tea_cache`` are not forwarded.

    Returns:
        The output of ``dit.unpatchify`` for the processed hidden states
        (or the stitched tiled result).
    """
    if tiled:
        def run_tile(hl, hr, wl, wr):
            # Crop the ControlNet frames to the same window as the latents.
            if controlnet_frames is not None:
                tiled_controlnet_frames = [frame[:, :, hl: hr, wl: wr] for frame in controlnet_frames]
            else:
                tiled_controlnet_frames = None
            return lets_dance_flux(
                dit=dit,
                controlnet=controlnet,
                hidden_states=hidden_states[:, :, hl: hr, wl: wr],
                timestep=timestep,
                prompt_emb=prompt_emb,
                pooled_prompt_emb=pooled_prompt_emb,
                guidance=guidance,
                text_ids=text_ids,
                image_ids=None,
                controlnet_frames=tiled_controlnet_frames,
                tiled=False,
                **kwargs
            )

        return FastTileWorker().tiled_forward(
            run_tile,
            hidden_states,
            tile_size=tile_size,
            tile_stride=tile_stride,
            tile_device=hidden_states.device,
            tile_dtype=hidden_states.dtype
        )

    # ControlNet: the same condition gates the residual additions further down.
    use_controlnet = controlnet is not None and controlnet_frames is not None
    if use_controlnet:
        # `image_ids` is passed as received; it may still be None at this point.
        controlnet_res_stack, controlnet_single_res_stack = controlnet(
            controlnet_frames,
            hidden_states=hidden_states,
            timestep=timestep,
            prompt_emb=prompt_emb,
            pooled_prompt_emb=pooled_prompt_emb,
            guidance=guidance,
            text_ids=text_ids,
            image_ids=image_ids,
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
        )

    if image_ids is None:
        image_ids = dit.prepare_image_ids(hidden_states)

    # Conditioning vector: timestep + pooled text (+ guidance when the model has an embedder for it).
    conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb)
    if dit.guidance_embedder is not None:
        guidance = guidance * 1000
        conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)

    height, width = hidden_states.shape[-2:]
    hidden_states = dit.patchify(hidden_states)
    hidden_states = dit.x_embedder(hidden_states)

    if entity_prompt_emb is not None and entity_masks is not None:
        prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
    else:
        prompt_emb = dit.context_embedder(prompt_emb)
        image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
        attention_mask = None

    # TeaCache
    tea_cache_update = tea_cache.check(dit, hidden_states, conditioning) if tea_cache is not None else False

    if tea_cache_update:
        # Reuse the residual from the last full computation.
        hidden_states = tea_cache.update(hidden_states)
    else:
        # Joint Blocks
        for joint_id, joint_block in enumerate(dit.blocks):
            hidden_states, prompt_emb = joint_block(
                hidden_states,
                prompt_emb,
                conditioning,
                image_rotary_emb,
                attention_mask,
                ipadapter_kwargs_list=ipadapter_kwargs_list.get(joint_id)
            )
            # ControlNet
            if use_controlnet:
                hidden_states = hidden_states + controlnet_res_stack[joint_id]

        # Single Blocks
        hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
        num_joint_blocks = len(dit.blocks)
        for single_id, single_block in enumerate(dit.single_blocks):
            hidden_states, prompt_emb = single_block(
                hidden_states,
                prompt_emb,
                conditioning,
                image_rotary_emb,
                attention_mask,
                ipadapter_kwargs_list=ipadapter_kwargs_list.get(single_id + num_joint_blocks)
            )
            # ControlNet
            if use_controlnet:
                # Residuals are added only after the first prompt_emb.shape[1] positions.
                hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[single_id]
        hidden_states = hidden_states[:, prompt_emb.shape[1]:]

        if tea_cache is not None:
            tea_cache.store(hidden_states)

    hidden_states = dit.final_norm_out(hidden_states, conditioning)
    hidden_states = dit.final_proj_out(hidden_states)
    hidden_states = dit.unpatchify(hidden_states, height, width)

    return hidden_states
|
|
||||||
@@ -3,11 +3,11 @@ from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, Hunyuan
|
|||||||
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
||||||
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
||||||
from ..models import ModelManager
|
from ..models import ModelManager
|
||||||
from ..prompters import HunyuanDiTPrompter
|
from ..prompts import HunyuanDiTPrompter
|
||||||
from ..schedulers import EnhancedDDIMScheduler
|
from ..schedulers import EnhancedDDIMScheduler
|
||||||
from .base import BasePipeline
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
@@ -122,12 +122,14 @@ class ImageSizeManager:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanDiTImagePipeline(BasePipeline):
|
class HunyuanDiTImagePipeline(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||||
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
|
super().__init__()
|
||||||
self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
|
self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
|
||||||
self.prompter = HunyuanDiTPrompter()
|
self.prompter = HunyuanDiTPrompter()
|
||||||
|
self.device = device
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
self.image_size_manager = ImageSizeManager()
|
self.image_size_manager = ImageSizeManager()
|
||||||
# models
|
# models
|
||||||
self.text_encoder: HunyuanDiTCLIPTextEncoder = None
|
self.text_encoder: HunyuanDiTCLIPTextEncoder = None
|
||||||
@@ -135,63 +137,44 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
|||||||
self.dit: HunyuanDiT = None
|
self.dit: HunyuanDiT = None
|
||||||
self.vae_decoder: SDXLVAEDecoder = None
|
self.vae_decoder: SDXLVAEDecoder = None
|
||||||
self.vae_encoder: SDXLVAEEncoder = None
|
self.vae_encoder: SDXLVAEEncoder = None
|
||||||
self.model_names = ['text_encoder', 'text_encoder_t5', 'dit', 'vae_decoder', 'vae_encoder']
|
|
||||||
|
|
||||||
|
|
||||||
def denoising_model(self):
|
def fetch_main_models(self, model_manager: ModelManager):
|
||||||
return self.dit
|
self.text_encoder = model_manager.hunyuan_dit_clip_text_encoder
|
||||||
|
self.text_encoder_t5 = model_manager.hunyuan_dit_t5_text_encoder
|
||||||
|
self.dit = model_manager.hunyuan_dit
|
||||||
|
self.vae_decoder = model_manager.vae_decoder
|
||||||
|
self.vae_encoder = model_manager.vae_encoder
|
||||||
|
|
||||||
|
|
||||||
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
|
def fetch_prompter(self, model_manager: ModelManager):
|
||||||
# Main models
|
self.prompter.load_from_model_manager(model_manager)
|
||||||
self.text_encoder = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
|
|
||||||
self.text_encoder_t5 = model_manager.fetch_model("hunyuan_dit_t5_text_encoder")
|
|
||||||
self.dit = model_manager.fetch_model("hunyuan_dit")
|
|
||||||
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
|
|
||||||
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
|
|
||||||
self.prompter.fetch_models(self.text_encoder, self.text_encoder_t5)
|
|
||||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
|
def from_model_manager(model_manager: ModelManager):
|
||||||
pipe = HunyuanDiTImagePipeline(
|
pipe = HunyuanDiTImagePipeline(
|
||||||
device=model_manager.device if device is None else device,
|
device=model_manager.device,
|
||||||
torch_dtype=model_manager.torch_dtype,
|
torch_dtype=model_manager.torch_dtype,
|
||||||
)
|
)
|
||||||
pipe.fetch_models(model_manager, prompt_refiner_classes)
|
pipe.fetch_main_models(model_manager)
|
||||||
|
pipe.fetch_prompter(model_manager)
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
|
def preprocess_image(self, image):
|
||||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||||
return latents
|
|
||||||
|
|
||||||
|
|
||||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
|
||||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
image = self.vae_output_to_image(image)
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=1, positive=True):
|
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||||
text_emb, text_emb_mask, text_emb_t5, text_emb_mask_t5 = self.prompter.encode_prompt(
|
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||||
prompt,
|
image = image.cpu().permute(1, 2, 0).numpy()
|
||||||
clip_skip=clip_skip,
|
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||||
clip_skip_2=clip_skip_2,
|
return image
|
||||||
positive=positive,
|
|
||||||
device=self.device
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"text_emb": text_emb,
|
|
||||||
"text_emb_mask": text_emb_mask,
|
|
||||||
"text_emb_t5": text_emb_t5,
|
|
||||||
"text_emb_mask_t5": text_emb_mask_t5
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_extra_input(self, latents=None, tiled=False, tile_size=64, tile_stride=32):
|
def prepare_extra_input(self, height=1024, width=1024, tiled=False, tile_size=64, tile_stride=32, batch_size=1):
|
||||||
batch_size, height, width = latents.shape[0], latents.shape[2] * 8, latents.shape[3] * 8
|
|
||||||
if tiled:
|
if tiled:
|
||||||
height, width = tile_size * 16, tile_size * 16
|
height, width = tile_size * 16, tile_size * 16
|
||||||
image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
|
image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
|
||||||
@@ -210,14 +193,12 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prompt,
|
prompt,
|
||||||
local_prompts=[],
|
|
||||||
masks=[],
|
|
||||||
mask_scales=[],
|
|
||||||
negative_prompt="",
|
negative_prompt="",
|
||||||
cfg_scale=7.5,
|
cfg_scale=7.5,
|
||||||
clip_skip=1,
|
clip_skip=1,
|
||||||
clip_skip_2=1,
|
clip_skip_2=1,
|
||||||
input_image=None,
|
input_image=None,
|
||||||
|
reference_images=[],
|
||||||
reference_strengths=[0.4],
|
reference_strengths=[0.4],
|
||||||
denoising_strength=1.0,
|
denoising_strength=1.0,
|
||||||
height=1024,
|
height=1024,
|
||||||
@@ -226,48 +207,80 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
|||||||
tiled=False,
|
tiled=False,
|
||||||
tile_size=64,
|
tile_size=64,
|
||||||
tile_stride=32,
|
tile_stride=32,
|
||||||
seed=None,
|
|
||||||
progress_bar_cmd=tqdm,
|
progress_bar_cmd=tqdm,
|
||||||
progress_bar_st=None,
|
progress_bar_st=None,
|
||||||
):
|
):
|
||||||
height, width = self.check_resize_height_width(height, width)
|
|
||||||
|
|
||||||
# Prepare scheduler
|
# Prepare scheduler
|
||||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||||
|
|
||||||
# Prepare latent tensors
|
# Prepare latent tensors
|
||||||
noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||||
if input_image is not None:
|
if input_image is not None:
|
||||||
self.load_models_to_device(['vae_encoder'])
|
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
|
|
||||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
|
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
|
||||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||||
else:
|
else:
|
||||||
latents = noise.clone()
|
latents = noise.clone()
|
||||||
|
|
||||||
|
# Prepare reference latents
|
||||||
|
reference_latents = []
|
||||||
|
for reference_image in reference_images:
|
||||||
|
reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
|
||||||
|
reference_latents.append(self.vae_encoder(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype))
|
||||||
|
|
||||||
# Encode prompts
|
# Encode prompts
|
||||||
self.load_models_to_device(['text_encoder', 'text_encoder_t5'])
|
prompt_emb_posi, attention_mask_posi, prompt_emb_t5_posi, attention_mask_t5_posi = self.prompter.encode_prompt(
|
||||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
self.text_encoder,
|
||||||
|
self.text_encoder_t5,
|
||||||
|
prompt,
|
||||||
|
clip_skip=clip_skip,
|
||||||
|
clip_skip_2=clip_skip_2,
|
||||||
|
positive=True,
|
||||||
|
device=self.device
|
||||||
|
)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
prompt_emb_nega, attention_mask_nega, prompt_emb_t5_nega, attention_mask_t5_nega = self.prompter.encode_prompt(
|
||||||
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
|
self.text_encoder,
|
||||||
|
self.text_encoder_t5,
|
||||||
|
negative_prompt,
|
||||||
|
clip_skip=clip_skip,
|
||||||
|
clip_skip_2=clip_skip_2,
|
||||||
|
positive=False,
|
||||||
|
device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare positional id
|
# Prepare positional id
|
||||||
extra_input = self.prepare_extra_input(latents, tiled, tile_size)
|
extra_input = self.prepare_extra_input(height, width, tiled, tile_size)
|
||||||
|
|
||||||
# Denoise
|
# Denoise
|
||||||
self.load_models_to_device(['dit'])
|
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
|
timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
|
||||||
|
# In-context reference
|
||||||
|
for reference_latents_, reference_strength in zip(reference_latents, reference_strengths):
|
||||||
|
if progress_id < num_inference_steps * reference_strength:
|
||||||
|
noisy_reference_latents = self.scheduler.add_noise(reference_latents_, noise, self.scheduler.timesteps[progress_id])
|
||||||
|
self.dit(
|
||||||
|
noisy_reference_latents,
|
||||||
|
prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
|
||||||
|
timestep,
|
||||||
|
**extra_input,
|
||||||
|
to_cache=True
|
||||||
|
)
|
||||||
# Positive side
|
# Positive side
|
||||||
inference_callback = lambda prompt_emb_posi: self.dit(latents, timestep=timestep, **prompt_emb_posi, **extra_input)
|
noise_pred_posi = self.dit(
|
||||||
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
|
latents,
|
||||||
|
prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
|
||||||
|
timestep,
|
||||||
|
**extra_input,
|
||||||
|
)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
# Negative side
|
# Negative side
|
||||||
noise_pred_nega = self.dit(
|
noise_pred_nega = self.dit(
|
||||||
latents, timestep=timestep, **prompt_emb_nega, **extra_input,
|
latents,
|
||||||
|
prompt_emb_nega, prompt_emb_t5_nega, attention_mask_nega, attention_mask_t5_nega,
|
||||||
|
timestep,
|
||||||
|
**extra_input
|
||||||
)
|
)
|
||||||
# Classifier-free guidance
|
# Classifier-free guidance
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
@@ -280,9 +293,6 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
|||||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||||
|
|
||||||
# Decode image
|
# Decode image
|
||||||
self.load_models_to_device(['vae_decoder'])
|
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
|
|
||||||
# Offload all models
|
|
||||||
self.load_models_to_device([])
|
|
||||||
return image
|
return image
|
||||||
@@ -1,265 +0,0 @@
|
|||||||
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder
|
|
||||||
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
|
||||||
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
|
|
||||||
from ..schedulers.flow_match import FlowMatchScheduler
|
|
||||||
from .base import BasePipeline
|
|
||||||
from ..prompters import HunyuanVideoPrompter
|
|
||||||
import torch
|
|
||||||
from einops import rearrange
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoPipeline(BasePipeline):
|
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
    """Set up an empty pipeline; models are attached later by ``fetch_models``."""
    super().__init__(device=device, torch_dtype=torch_dtype)

    # Sampling and prompt-handling helpers.
    self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
    self.prompter = HunyuanVideoPrompter()

    # Model slots, all empty until models are fetched.
    self.text_encoder_1: SD3TextEncoder1 = None
    self.text_encoder_2: HunyuanVideoLLMEncoder = None
    self.dit: HunyuanVideoDiT = None
    self.vae_decoder: HunyuanVideoVAEDecoder = None
    self.vae_encoder: HunyuanVideoVAEEncoder = None

    # Attribute names of the model slots above.
    self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']

    # Flipped to True by enable_vram_management().
    self.vram_management = False
|
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management(self):
    """Turn on CPU offload and auto-offload for the text encoder 2 and the DiT."""
    self.vram_management = True
    self.enable_cpu_offload()
    for model in (self.text_encoder_2, self.dit):
        model.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_models(self, model_manager: ModelManager):
    """Pull the five pipeline models out of ``model_manager`` and wire the prompter."""
    # (attribute on this pipeline, name requested from the model manager)
    slots = (
        ("text_encoder_1", "sd3_text_encoder_1"),
        ("text_encoder_2", "hunyuan_video_text_encoder_2"),
        ("dit", "hunyuan_video_dit"),
        ("vae_decoder", "hunyuan_video_vae_decoder"),
        ("vae_encoder", "hunyuan_video_vae_encoder"),
    )
    for attribute, model_name in slots:
        setattr(self, attribute, model_manager.fetch_model(model_name))
    self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, enable_vram_management=True):
    """Build a ``HunyuanVideoPipeline`` from the models held by ``model_manager``.

    ``device`` and ``torch_dtype`` fall back to the manager's values when
    left as ``None``.  VRAM management is enabled unless
    ``enable_vram_management`` is falsy.
    """
    device = model_manager.device if device is None else device
    torch_dtype = model_manager.torch_dtype if torch_dtype is None else torch_dtype

    pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype)
    pipe.fetch_models(model_manager)
    if enable_vram_management:
        pipe.enable_vram_management()
    return pipe
|
|
||||||
|
|
||||||
|
|
||||||
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
    """Encode ``prompt`` with the pipeline's prompter.

    Returns a dict with the keys ``prompt_emb``, ``pooled_prompt_emb`` and
    ``text_mask``, holding the three values returned by the prompter in
    that order.
    """
    encoded = self.prompter.encode_prompt(
        prompt,
        device=self.device,
        positive=positive,
        clip_sequence_length=clip_sequence_length,
        llm_sequence_length=llm_sequence_length,
    )
    prompt_emb, pooled_prompt_emb, text_mask = encoded
    return {
        "prompt_emb": prompt_emb,
        "pooled_prompt_emb": pooled_prompt_emb,
        "text_mask": text_mask,
    }
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_extra_input(self, latents=None, guidance=1.0):
    """Build the extra conditioning inputs for the DiT.

    Returns a dict with ``freqs_cos`` / ``freqs_sin`` (from
    ``self.dit.prepare_freqs(latents)``) and ``guidance``: a 1-D tensor that
    repeats the guidance value once per batch entry of ``latents`` and lives
    on the same device and dtype as ``latents``.
    """
    freqs_cos, freqs_sin = self.dit.prepare_freqs(latents)
    batch_size = latents.shape[0]
    guidance_tensor = torch.Tensor([guidance] * batch_size)
    guidance_tensor = guidance_tensor.to(device=latents.device, dtype=latents.dtype)
    return {"freqs_cos": freqs_cos, "freqs_sin": freqs_sin, "guidance": guidance_tensor}
|
|
||||||
|
|
||||||
|
|
||||||
def tensor2video(self, frames):
    """Convert a ``C T H W`` frame tensor into a list of PIL images.

    Values are mapped with ``(x + 1) * 127.5``, clipped to ``[0, 255]`` and
    cast to ``uint8`` before each frame is wrapped in an ``Image``.
    """
    channels_last = rearrange(frames, "C T H W -> T H W C")
    pixels = (channels_last.float() + 1) * 127.5
    pixels = pixels.clip(0, 255).cpu().numpy().astype(np.uint8)
    return [Image.fromarray(single_frame) for single_frame in pixels]
|
|
||||||
|
|
||||||
|
|
||||||
def encode_video(self, frames, tile_size=(17, 30, 30), tile_stride=(12, 20, 20)):
    """Encode ``frames`` with the VAE encoder using tiled processing.

    ``tile_size`` and ``tile_stride`` are rescaled before being passed to the
    encoder: the first entry by a factor of 4 (size uses ``(n - 1) * 4 + 1``)
    and the other two entries by a factor of 8.
    """
    scaled_size = (
        (tile_size[0] - 1) * 4 + 1,
        tile_size[1] * 8,
        tile_size[2] * 8,
    )
    scaled_stride = (
        tile_stride[0] * 4,
        tile_stride[1] * 8,
        tile_stride[2] * 8,
    )
    return self.vae_encoder.encode_video(frames, tile_size=scaled_size, tile_stride=scaled_stride)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
def __call__(
    self,
    prompt,
    negative_prompt="",
    input_video=None,
    denoising_strength=1.0,
    seed=None,
    rand_device=None,
    height=720,
    width=1280,
    num_frames=129,
    embedded_guidance=6.0,
    cfg_scale=1.0,
    num_inference_steps=30,
    tea_cache_l1_thresh=None,
    tile_size=(17, 30, 30),
    tile_stride=(12, 20, 20),
    step_processor=None,
    progress_bar_cmd=lambda x: x,
    progress_bar_st=None,
):
    """Generate a video from a text prompt, optionally starting from ``input_video``.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Negative prompt; only encoded when ``cfg_scale != 1.0``.
        input_video: Optional frames. When given they are VAE-encoded and
            noised to the first scheduler timestep; otherwise pure noise is used.
        denoising_strength: Forwarded to ``scheduler.set_timesteps``.
        seed: Seed forwarded to ``generate_noise``.
        rand_device: Device on which the noise is drawn (defaults to the
            pipeline device); the noise is then moved to the pipeline device.
        height, width, num_frames: Output size. The noise tensor has shape
            ``(1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8)``.
        embedded_guidance: Guidance value passed to ``prepare_extra_input``.
        cfg_scale: Classifier-free guidance scale; ``1.0`` skips the negative pass.
        num_inference_steps: Number of scheduler steps.
        tea_cache_l1_thresh: When not ``None``, a ``TeaCache`` with this
            threshold is passed to the positive forward pass.
        tile_size, tile_stride: Tiling parameters for VAE encode/decode.
        step_processor: Optional callable applied to the decoded frames of
            every step (experimental).
        progress_bar_cmd: Wrapper applied to the timestep iterable.
        progress_bar_st: Accepted for interface compatibility; not used here.

    Returns:
        The decoded frames as produced by ``tensor2video``.
    """
    # Tiler parameters
    tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}

    # Scheduler
    self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

    # Initialize noise
    rand_device = self.device if rand_device is None else rand_device
    noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
    if input_video is not None:
        self.load_models_to_device(['vae_encoder'])
        input_video = self.preprocess_images(input_video)
        input_video = torch.stack(input_video, dim=2)
        latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
        latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
    else:
        latents = noise

    # Encode prompts
    self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
    prompt_emb_posi = self.encode_prompt(prompt, positive=True)
    if cfg_scale != 1.0:
        prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)

    # Extra input
    extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)

    # TeaCache
    tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}

    # Denoise
    self.load_models_to_device([] if self.vram_management else ["dit"])
    for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
        timestep = timestep.unsqueeze(0).to(self.device)
        print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")

        # Inference
        with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
            noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
            if cfg_scale != 1.0:
                # The TeaCache is only handed to the positive pass; the negative pass always runs in full.
                noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi

        # (Experimental feature, may be removed in the future)
        if step_processor is not None:
            self.load_models_to_device(['vae_decoder'])
            rendered_frames = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents, to_final=True)
            rendered_frames = self.vae_decoder.decode_video(rendered_frames, **tiler_kwargs)
            rendered_frames = self.tensor2video(rendered_frames[0])
            rendered_frames = step_processor(rendered_frames, original_frames=input_video)
            self.load_models_to_device(['vae_encoder'])
            rendered_frames = self.preprocess_images(rendered_frames)
            rendered_frames = torch.stack(rendered_frames, dim=2)
            # Use the caller's tiling here too, as the input-video encode above does.
            target_latents = self.encode_video(rendered_frames, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.scheduler.return_to_timestep(self.scheduler.timesteps[progress_id], latents, target_latents)
            self.load_models_to_device([] if self.vram_management else ["dit"])

        # Scheduler
        latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

    # Decode
    self.load_models_to_device(['vae_decoder'])
    frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
    self.load_models_to_device([])
    frames = self.tensor2video(frames[0])

    return frames
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TeaCache:
    """Decides per step whether the DiT blocks can be skipped.

    The decision is based on how much the modulated input of the first
    double-stream block changed since the previous step. When a step is
    skipped, the residual stored from the last fully computed step is added
    to the hidden states instead.
    """

    def __init__(self, num_inference_steps, rel_l1_thresh):
        self.num_inference_steps = num_inference_steps
        self.rel_l1_thresh = rel_l1_thresh
        # Position inside the current denoising run; wraps back to 0 at the end.
        self.step = 0
        # Rescaled relative-L1 change accumulated since the last full computation.
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None
        self.previous_hidden_states = None

    def check(self, dit: HunyuanVideoDiT, img, vec):
        """Return ``True`` when the forward pass through the blocks may be skipped."""
        img_copy = img.clone()
        vec_copy = vec.clone()
        first_block = dit.double_blocks[0].component_a
        shift, scale, _, _, _, _ = first_block.mod(vec_copy).chunk(6, dim=-1)
        modulated = first_block.norm1(img_copy) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

        if self.step in (0, self.num_inference_steps - 1):
            # The first and last step are always computed in full.
            recompute = True
        else:
            rescale = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02])
            change = (modulated - self.previous_modulated_input).abs().mean()
            reference = self.previous_modulated_input.abs().mean()
            self.accumulated_rel_l1_distance += rescale((change / reference).cpu().item())
            recompute = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
        if recompute:
            self.accumulated_rel_l1_distance = 0

        self.previous_modulated_input = modulated
        self.step += 1
        if self.step == self.num_inference_steps:
            self.step = 0
        if recompute:
            # Remember the block input so that store() can derive the residual.
            self.previous_hidden_states = img.clone()
        return not recompute

    def store(self, hidden_states):
        """Record the residual between the block output and the remembered block input."""
        residual = hidden_states - self.previous_hidden_states
        self.previous_residual = residual
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Apply the stored residual to ``hidden_states`` and return the result."""
        return hidden_states + self.previous_residual
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def lets_dance_hunyuan_video(
    dit: HunyuanVideoDiT,
    x: torch.Tensor,
    t: torch.Tensor,
    prompt_emb: torch.Tensor = None,
    text_mask: torch.Tensor = None,
    pooled_prompt_emb: torch.Tensor = None,
    freqs_cos: torch.Tensor = None,
    freqs_sin: torch.Tensor = None,
    guidance: torch.Tensor = None,
    tea_cache: TeaCache = None,
    **kwargs
):
    """Run one forward pass of the HunyuanVideo DiT.

    Args:
        dit: The DiT whose embedding layers, blocks and final layer are used.
        x: Latent input; must be 5-dimensional (unpacked as ``B, C, T, H, W``).
        t: Timestep tensor.
        prompt_emb, text_mask: Text embedding and mask fed to ``dit.txt_in``.
        pooled_prompt_emb: Pooled text embedding fed to ``dit.vector_in``.
        freqs_cos, freqs_sin: Positional frequencies passed to every block.
        guidance: Guidance values; multiplied by 1000 before ``dit.guidance_in``.
        tea_cache: Optional ``TeaCache``. When its ``check`` returns ``True``
            the blocks are skipped and the cached residual is applied instead.
        **kwargs: Ignored.

    Returns:
        The output of ``dit.unpatchify`` on the final layer's output.
    """
    _, _, T, H, W = x.shape

    vec = dit.time_in(t, dtype=torch.float32) + dit.vector_in(pooled_prompt_emb) + dit.guidance_in(guidance * 1000, dtype=torch.float32)
    img = dit.img_in(x)
    txt = dit.txt_in(prompt_emb, t, text_mask)

    # TeaCache
    if tea_cache is not None:
        tea_cache_update = tea_cache.check(dit, img, vec)
    else:
        tea_cache_update = False

    if tea_cache_update:
        print("TeaCache skip forward.")
        img = tea_cache.update(img)
    else:
        for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
            img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))

        # Text tokens are appended after the image tokens, so the split point is
        # derived from the actual text length instead of assuming 256 tokens.
        txt_len = txt.shape[1]
        x = torch.concat([img, txt], dim=1)
        for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
            x = block(x, vec, (freqs_cos, freqs_sin))
        img = x[:, :x.shape[1] - txt_len]

        if tea_cache is not None:
            tea_cache.store(img)
    img = dit.final_layer(img, vec)
    img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
    return img
|
|
||||||
@@ -1,289 +0,0 @@
|
|||||||
from ..models.omnigen import OmniGenTransformer
|
|
||||||
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
|
||||||
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
|
||||||
from ..models.model_manager import ModelManager
|
|
||||||
from ..prompters.omnigen_prompter import OmniGenPrompter
|
|
||||||
from ..schedulers import FlowMatchScheduler
|
|
||||||
from .base import BasePipeline
|
|
||||||
from typing import Optional, Dict, Any, Tuple, List
|
|
||||||
from transformers.cache_utils import DynamicCache
|
|
||||||
import torch, os
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class OmniGenCache(DynamicCache):
    """KV cache that keeps only the condition tokens and can offload layers to the CPU.

    When ``offload_kv_cache`` is enabled, reading layer ``i`` moves the
    previous layer to the CPU and prefetches layer ``i + 1`` on a separate
    CUDA stream.
    """

    def __init__(self,
                 num_tokens_for_img: int, offload_kv_cache: bool=False) -> None:
        """
        Parameters:
            num_tokens_for_img (`int`):
                Number of image tokens; the last ``num_tokens_for_img + 1``
                positions are dropped before the states are cached.
            offload_kv_cache (`bool`):
                Whether cached layers are moved to the CPU when not in use.
                Forced to ``False`` when CUDA is unavailable.
        """
        cuda_available = torch.cuda.is_available()
        if not cuda_available:
            # Fall back to a non-offloading cache instead of failing; previously
            # this message was printed and a RuntimeError was raised right after.
            print("No avaliable GPU, offload_kv_cache wiil be set to False, which will result in large memory usage and time cost when input multiple images!!!")
            offload_kv_cache = False
        super().__init__()
        self.original_device = []
        # The prefetch stream is only needed (and only constructible) with CUDA.
        self.prefetch_stream = torch.cuda.Stream() if cuda_available else None
        self.num_tokens_for_img = num_tokens_for_img
        self.offload_kv_cache = offload_kv_cache

    def prefetch_layer(self, layer_idx: int):
        "Starts prefetching the given layer's cache back to its original device"
        if layer_idx < len(self):
            with torch.cuda.stream(self.prefetch_stream):
                # Prefetch next layer tensors to GPU
                device = self.original_device[layer_idx]
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)

    def evict_previous_layer(self, layer_idx: int):
        "Moves the previous layer cache to the CPU"
        if len(self) > 2:
            # We do it on the default stream so it occurs after all earlier computations on these tensors are done
            if layer_idx == 0:
                prev_layer_idx = -1
            else:
                prev_layer_idx = (layer_idx - 1) % len(self)
            self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
            self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
        if layer_idx < len(self):
            if self.offload_kv_cache:
                # Evict the previous layer if necessary
                torch.cuda.current_stream().synchronize()
                self.evict_previous_layer(layer_idx)
                # Load current layer cache to its original device if not already there
                original_device = self.original_device[layer_idx]
                # NOTE(review): torch.cuda.synchronize expects a device, not a stream;
                # self.prefetch_stream.synchronize() is probably what was intended - confirm.
                torch.cuda.synchronize(self.prefetch_stream)
                key_tensor = self.key_cache[layer_idx]
                value_tensor = self.value_cache[layer_idx]

                # Prefetch the next layer
                self.prefetch_layer((layer_idx + 1) % len(self))
            else:
                key_tensor = self.key_cache[layer_idx]
                value_tensor = self.value_cache[layer_idx]
            return (key_tensor, value_tensor)
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        The first call for a layer stores the states (without the trailing
        image tokens). Later calls leave the cache untouched and return the
        cached states concatenated with the new ones.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
        Return:
            A tuple containing the updated key and value states.
        Raises:
            ValueError: if `layer_idx` skips over layers that have not been cached yet.
        """
        # Update the cache
        if len(self.key_cache) < layer_idx:
            raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
        elif len(self.key_cache) == layer_idx:
            # only cache the states for condition tokens
            key_states = key_states[..., :-(self.num_tokens_for_img+1), :]
            value_states = value_states[..., :-(self.num_tokens_for_img+1), :]

            # Update the number of seen tokens
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]

            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            self.original_device.append(key_states.device)
            if self.offload_kv_cache:
                self.evict_previous_layer(layer_idx)
            return self.key_cache[layer_idx], self.value_cache[layer_idx]
        else:
            # The layer is already cached: prepend the cached condition tokens to the new states
            key_tensor, value_tensor = self[layer_idx]
            k = torch.cat([key_tensor, key_states], dim=-2)
            v = torch.cat([value_tensor, value_states], dim=-2)
            return k, v
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class OmnigenImagePipeline(BasePipeline):
    """Image generation pipeline built around the OmniGen transformer and the SDXL VAE."""

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = FlowMatchScheduler(num_train_timesteps=1, shift=1, inverse_timesteps=True, sigma_min=0, sigma_max=1)
        # models
        self.vae_decoder: SDXLVAEDecoder = None
        self.vae_encoder: SDXLVAEEncoder = None
        self.transformer: OmniGenTransformer = None
        self.prompter: OmniGenPrompter = None
        self.model_names = ['transformer', 'vae_decoder', 'vae_encoder']


    def denoising_model(self):
        """Return the model that performs the denoising (the transformer)."""
        return self.transformer


    def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
        """Fetch the transformer and both VAE halves, and load the prompter.

        The prompter is loaded from the directory that contains the
        transformer's model file. ``prompt_refiner_classes`` is accepted for
        interface compatibility and is not used here.
        """
        # Main models
        self.transformer, model_path = model_manager.fetch_model("omnigen_transformer", require_model_path=True)
        self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
        self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
        self.prompter = OmniGenPrompter.from_pretrained(os.path.dirname(model_path))


    @staticmethod
    def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
        """Build a pipeline from a loaded model manager.

        ``device`` falls back to the model manager's device when ``None``.
        """
        pipe = OmnigenImagePipeline(
            device=model_manager.device if device is None else device,
            torch_dtype=model_manager.torch_dtype,
        )
        # Forward the caller's argument instead of silently replacing it with [].
        pipe.fetch_models(model_manager, prompt_refiner_classes=prompt_refiner_classes)
        return pipe


    def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
        """Encode one image with the VAE encoder and return the latents."""
        latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents


    def encode_images(self, images, tiled=False, tile_size=64, tile_stride=32):
        """Encode several images; each is moved to the pipeline device first and the latents are cast to the pipeline dtype."""
        latents = [self.encode_image(image.to(device=self.device), tiled, tile_size, tile_stride).to(self.torch_dtype) for image in images]
        return latents


    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        """Decode latents with the VAE decoder and convert the result with ``vae_output_to_image``."""
        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        image = self.vae_output_to_image(image)
        return image


    def encode_prompt(self, prompt, clip_skip=1, positive=True):
        """Encode ``prompt`` via the prompter's ``encode_prompt`` and wrap it as ``encoder_hidden_states``."""
        prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
        return {"encoder_hidden_states": prompt_emb}


    def prepare_extra_input(self, latents=None):
        """This pipeline has no extra model inputs; always returns an empty dict."""
        return {}


    def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
        """Keep only the last ``num_tokens_for_img + 1`` positions.

        A list input is cropped element by element, in place, and the same
        list object is returned.
        """
        if isinstance(position_ids, list):
            for i in range(len(position_ids)):
                position_ids[i] = position_ids[i][:, -(num_tokens_for_img+1):]
        else:
            position_ids = position_ids[:, -(num_tokens_for_img+1):]
        return position_ids


    def crop_attention_mask_for_cache(self, attention_mask, num_tokens_for_img):
        """Keep only the last ``num_tokens_for_img + 1`` rows of the second-to-last axis (per element for a list)."""
        if isinstance(attention_mask, list):
            return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
        return attention_mask[..., -(num_tokens_for_img+1):, :]


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        reference_images=[],
        cfg_scale=2.0,
        image_cfg_scale=2.0,
        use_kv_cache=True,
        offload_kv_cache=True,
        input_image=None,
        denoising_strength=1.0,
        height=1024,
        width=1024,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        seed=None,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate an image from ``prompt`` and optional ``reference_images``.

        Args:
            prompt: Text instruction passed to the prompter.
            reference_images: Reference images passed to the prompter.
            cfg_scale, image_cfg_scale: Guidance scales forwarded to the transformer.
            use_kv_cache: Whether to build an ``OmniGenCache`` per input sequence.
            offload_kv_cache: Forwarded to each ``OmniGenCache``.
            input_image: Optional start image; encoded and noised to the first timestep.
            denoising_strength: Forwarded to ``scheduler.set_timesteps``.
            height, width: Output size, adjusted by ``check_resize_height_width``.
            num_inference_steps: Number of scheduler steps.
            tiled, tile_size, tile_stride: VAE tiling parameters.
            seed: Seed forwarded to ``generate_noise``.
            progress_bar_cmd: Wrapper applied to the timestep iterable.
            progress_bar_st: Optional object whose ``progress`` method is
                called with the completed fraction after each step.

        Returns:
            The decoded image.
        """
        height, width = self.check_resize_height_width(height, width)

        # Tiler parameters
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        if input_image is not None:
            self.load_models_to_device(['vae_encoder'])
            image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
            latents = self.encode_image(image, **tiler_kwargs)
            noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
        # The batch is tripled for forward_with_separate_cfg (called with use_img_cfg=True).
        latents = latents.repeat(3, 1, 1, 1)

        # Encode prompts
        input_data = self.prompter(prompt, reference_images, height=height, width=width, use_img_cfg=True, separate_cfg_input=True, use_input_image_size_as_output=False)

        # Encode images
        reference_latents = [self.encode_images(images, **tiler_kwargs) for images in input_data['input_pixel_values']]

        # Pack all parameters
        model_kwargs = dict(input_ids=[input_ids.to(self.device) for input_ids in input_data['input_ids']],
            input_img_latents=reference_latents,
            input_image_sizes=input_data['input_image_sizes'],
            attention_mask=[attention_mask.to(self.device) for attention_mask in input_data["attention_mask"]],
            position_ids=[position_ids.to(self.device) for position_ids in input_data["position_ids"]],
            cfg_scale=cfg_scale,
            img_cfg_scale=image_cfg_scale,
            use_img_cfg=True,
            use_kv_cache=use_kv_cache,
            offload_model=False,
        )

        # Denoise
        self.load_models_to_device(['transformer'])
        cache = [OmniGenCache(latents.size(-1)*latents.size(-2) // 4, offload_kv_cache) for _ in range(len(model_kwargs['input_ids']))] if use_kv_cache else None
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).repeat(latents.shape[0]).to(self.device)

            # Forward
            noise_pred, cache = self.transformer.forward_with_separate_cfg(latents, timestep, past_key_values=cache, **model_kwargs)

            # Scheduler
            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

            # Update KV cache
            # After the first step the condition tokens live in the cache, so the
            # token ids are dropped and the masks/positions are cropped to the image part.
            if progress_id == 0 and use_kv_cache:
                num_tokens_for_img = latents.size(-1)*latents.size(-2) // 4
                if isinstance(cache, list):
                    model_kwargs['input_ids'] = [None] * len(cache)
                else:
                    model_kwargs['input_ids'] = None
                model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
                model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)

            # UI
            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        del cache
        self.load_models_to_device(['vae_decoder'])
        image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

        # offload all models
        self.load_models_to_device([])
        return image
|
|
||||||
@@ -1,105 +0,0 @@
|
|||||||
import os, torch, json
|
|
||||||
from .sd_video import ModelManager, SDVideoPipeline, ControlNetConfigUnit
|
|
||||||
from ..processors.sequencial_processor import SequencialProcessor
|
|
||||||
from ..data import VideoData, save_frames, save_video
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SDVideoPipelineRunner:
    """Runs an ``SDVideoPipeline`` end to end from a config dict.

    With ``in_streamlit=True`` progress messages and the final video are also
    rendered through streamlit.
    """

    def __init__(self, in_streamlit=False):
        self.in_streamlit = in_streamlit


    def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
        """Load the models, build the pipeline and register textual inversions.

        ``lora_alphas`` is accepted but not used by this method.

        Returns:
            ``(model_manager, pipe)``.
        """
        # Load models
        model_manager = ModelManager(torch_dtype=torch.float16, device=device)
        model_manager.load_models(model_list)
        pipe = SDVideoPipeline.from_model_manager(
            model_manager,
            [
                ControlNetConfigUnit(
                    processor_id=unit["processor_id"],
                    model_path=unit["model_path"],
                    scale=unit["scale"]
                ) for unit in controlnet_units
            ]
        )
        # Every weight-like file in the folder is treated as a textual inversion.
        textual_inversion_paths = [
            os.path.join(textual_inversion_folder, file_name)
            for file_name in os.listdir(textual_inversion_folder)
            if file_name.endswith((".pt", ".bin", ".pth", ".safetensors"))
        ]
        pipe.prompter.load_textual_inversions(textual_inversion_paths)
        return model_manager, pipe


    def load_smoother(self, model_manager, smoother_configs):
        """Build the smoother (a ``SequencialProcessor``) from its configs."""
        smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
        return smoother


    def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
        """Seed torch, run the pipeline, then move the models to the CPU.

        Returns:
            Whatever ``pipe`` returns.
        """
        torch.manual_seed(seed)
        if self.in_streamlit:
            import streamlit as st
            progress_bar_st = st.progress(0.0)
            output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
            progress_bar_st.progress(1.0)
        else:
            output_video = pipe(**pipeline_inputs, smoother=smoother)
        model_manager.to("cpu")
        return output_video


    def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
        """Load frames ``[start_frame_id, end_frame_id)`` of a video.

        ``None`` bounds default to the first frame and to the video length.
        """
        video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
        if start_frame_id is None:
            start_frame_id = 0
        if end_frame_id is None:
            end_frame_id = len(video)
        frames = [video[i] for i in range(start_frame_id, end_frame_id)]
        return frames


    def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
        """Fill ``pipeline_inputs`` (in place) with the loaded frames and their size.

        Sets ``input_frames``, ``num_frames``, ``width`` and ``height``, plus
        ``controlnet_frames`` when ``data["controlnet_frames"]`` is non-empty.
        """
        pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
        pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
        pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
        if len(data["controlnet_frames"]) > 0:
            pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
        return pipeline_inputs


    def save_output(self, video, output_folder, fps, config):
        """Write the frames, the encoded video and the config to ``output_folder``.

        The frame lists inside ``config`` are emptied (in place) before the
        config is dumped as JSON.
        """
        os.makedirs(output_folder, exist_ok=True)
        save_frames(video, os.path.join(output_folder, "frames"))
        save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
        config["pipeline"]["pipeline_inputs"]["input_frames"] = []
        config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
        with open(os.path.join(output_folder, "config.json"), 'w') as file:
            json.dump(config, file, indent=4)


    def run(self, config):
        """Execute the whole workflow described by ``config``.

        Loads the input videos and models, optionally a smoother, synthesizes
        the video and saves the results. In streamlit mode the saved video is
        displayed at the end.
        """
        if self.in_streamlit:
            import streamlit as st
        if self.in_streamlit: st.markdown("Loading videos ...")
        config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
        if self.in_streamlit: st.markdown("Loading videos ... done!")
        if self.in_streamlit: st.markdown("Loading models ...")
        model_manager, pipe = self.load_pipeline(**config["models"])
        if self.in_streamlit: st.markdown("Loading models ... done!")
        if "smoother_configs" in config:
            if self.in_streamlit: st.markdown("Loading smoother ...")
            smoother = self.load_smoother(model_manager, config["smoother_configs"])
            if self.in_streamlit: st.markdown("Loading smoother ... done!")
        else:
            smoother = None
        if self.in_streamlit: st.markdown("Synthesizing videos ...")
        output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
        if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
        if self.in_streamlit: st.markdown("Saving videos ...")
        self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
        if self.in_streamlit: st.markdown("Saving videos ... done!")
        if self.in_streamlit:
            st.markdown("Finished!")
            # Only open the file when it is actually displayed, and close it afterwards.
            video_path = os.path.join(config["data"]["output_folder"], "video.mp4")
            with open(video_path, 'rb') as video_file:
                st.video(video_file.read())
|
|
||||||
@@ -1,147 +0,0 @@
|
|||||||
from ..models import ModelManager, SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEDecoder, SD3VAEEncoder
|
|
||||||
from ..prompters import SD3Prompter
|
|
||||||
from ..schedulers import FlowMatchScheduler
|
|
||||||
from .base import BasePipeline
|
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SD3ImagePipeline(BasePipeline):
|
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
|
||||||
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
|
|
||||||
self.scheduler = FlowMatchScheduler()
|
|
||||||
self.prompter = SD3Prompter()
|
|
||||||
# models
|
|
||||||
self.text_encoder_1: SD3TextEncoder1 = None
|
|
||||||
self.text_encoder_2: SD3TextEncoder2 = None
|
|
||||||
self.text_encoder_3: SD3TextEncoder3 = None
|
|
||||||
self.dit: SD3DiT = None
|
|
||||||
self.vae_decoder: SD3VAEDecoder = None
|
|
||||||
self.vae_encoder: SD3VAEEncoder = None
|
|
||||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'text_encoder_3', 'dit', 'vae_decoder', 'vae_encoder']
|
|
||||||
|
|
||||||
|
|
||||||
def denoising_model(self):
|
|
||||||
return self.dit
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
|
|
||||||
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
|
|
||||||
self.text_encoder_2 = model_manager.fetch_model("sd3_text_encoder_2")
|
|
||||||
self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3")
|
|
||||||
self.dit = model_manager.fetch_model("sd3_dit")
|
|
||||||
self.vae_decoder = model_manager.fetch_model("sd3_vae_decoder")
|
|
||||||
self.vae_encoder = model_manager.fetch_model("sd3_vae_encoder")
|
|
||||||
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2, self.text_encoder_3)
|
|
||||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
|
|
||||||
pipe = SD3ImagePipeline(
|
|
||||||
device=model_manager.device if device is None else device,
|
|
||||||
torch_dtype=model_manager.torch_dtype,
|
|
||||||
)
|
|
||||||
pipe.fetch_models(model_manager, prompt_refiner_classes)
|
|
||||||
return pipe
|
|
||||||
|
|
||||||
|
|
||||||
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
    """Run ``self.vae_encoder`` on ``image`` and return its output.

    The tiling options are passed through to the encoder as keyword arguments.
    """
    tiling = dict(tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
    return self.vae_encoder(image, **tiling)
|
|
||||||
|
|
||||||
|
|
||||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
    """Decode ``latent`` with ``self.vae_decoder`` and convert the result to an image.

    The latent is moved to ``self.device`` first; the decoder output is then
    handed to ``self.vae_output_to_image``.
    """
    latent_on_device = latent.to(self.device)
    decoded = self.vae_decoder(
        latent_on_device,
        tiled=tiled,
        tile_size=tile_size,
        tile_stride=tile_stride,
    )
    return self.vae_output_to_image(decoded)
|
|
||||||
|
|
||||||
|
|
||||||
def encode_prompt(self, prompt, positive=True, t5_sequence_length=77):
    """Encode ``prompt`` through the prompter and package the two embeddings.

    Returns:
        A dict with keys ``"prompt_emb"`` and ``"pooled_prompt_emb"`` holding the
        first and second value returned by ``self.prompter.encode_prompt``.
    """
    embeddings = self.prompter.encode_prompt(
        prompt,
        device=self.device,
        positive=positive,
        t5_sequence_length=t5_sequence_length,
    )
    prompt_emb, pooled_prompt_emb = embeddings
    return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_extra_input(self, latents=None):
    """Return the extra model inputs for ``latents``; this pipeline has none."""
    extra_input = {}
    return extra_input
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
def __call__(
    self,
    prompt,
    local_prompts=[],
    masks=[],
    mask_scales=[],
    negative_prompt="",
    cfg_scale=7.5,
    input_image=None,
    denoising_strength=1.0,
    height=1024,
    width=1024,
    num_inference_steps=20,
    t5_sequence_length=77,
    tiled=False,
    tile_size=128,
    tile_stride=64,
    seed=None,
    progress_bar_cmd=tqdm,
    progress_bar_st=None,
):
    """Generate an image from ``prompt`` (optionally starting from ``input_image``).

    Args:
        prompt: Main positive prompt.
        local_prompts: Additional prompts, each encoded separately and combined
            with ``masks`` / ``mask_scales`` by ``control_noise_via_local_prompts``.
        masks: Passed to ``control_noise_via_local_prompts``.
        mask_scales: Passed to ``control_noise_via_local_prompts``.
        negative_prompt: Prompt used for the negative branch of the guidance.
        cfg_scale: Weight of the (positive - negative) difference added to the
            negative prediction.
        input_image: When given, it is encoded and noised to form the starting
            latents; otherwise the latents are pure noise.
        denoising_strength: Forwarded to ``scheduler.set_timesteps``.
        height: Requested height, adjusted by ``check_resize_height_width``.
        width: Requested width, adjusted by ``check_resize_height_width``.
        num_inference_steps: Forwarded to ``scheduler.set_timesteps``.
        t5_sequence_length: Forwarded to ``encode_prompt``.
        tiled: Tiling switch forwarded to the VAE and the DiT.
        tile_size: Tiling option forwarded alongside ``tiled``.
        tile_stride: Tiling option forwarded alongside ``tiled``.
        seed: Seed forwarded to ``generate_noise``.
        progress_bar_cmd: Callable wrapping the timestep iterable.
        progress_bar_st: Optional object whose ``progress`` method is called
            after every step with the fraction of steps started.

    Returns:
        The value produced by ``decode_image`` for the final latents.
    """
    height, width = self.check_resize_height_width(height, width)

    # Tiling options shared by the VAE encoder and the DiT calls.
    tiler_kwargs = dict(tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

    self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

    # Starting latents: pure noise, or the noised encoding of the input image.
    latent_shape = (1, 16, height//8, width//8)
    if input_image is None:
        latents = self.generate_noise(latent_shape, seed=seed, device=self.device, dtype=self.torch_dtype)
    else:
        self.load_models_to_device(['vae_encoder'])
        image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
        latents = self.encode_image(image, **tiler_kwargs)
        noise = self.generate_noise(latent_shape, seed=seed, device=self.device, dtype=self.torch_dtype)
        latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])

    # Prompt embeddings for the positive, negative and local prompts.
    self.load_models_to_device(['text_encoder_1', 'text_encoder_2', 'text_encoder_3'])
    prompt_emb_posi = self.encode_prompt(prompt, positive=True, t5_sequence_length=t5_sequence_length)
    prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
    prompt_emb_locals = [
        self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length)
        for prompt_local in local_prompts
    ]

    # Denoising loop.
    self.load_models_to_device(['dit'])
    for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
        timestep = timestep.unsqueeze(0).to(self.device)

        # Evaluated by control_noise_via_local_prompts for the prompt embeddings
        # it is given; reads the current ``latents`` and ``timestep``.
        def predict_noise(prompt_emb):
            return self.dit(latents, timestep=timestep, **prompt_emb, **tiler_kwargs)

        # Classifier-free guidance
        noise_pred_posi = self.control_noise_via_local_prompts(
            prompt_emb_posi, prompt_emb_locals, masks, mask_scales, predict_noise
        )
        noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs)
        guidance = noise_pred_posi - noise_pred_nega
        noise_pred = noise_pred_nega + cfg_scale * guidance

        # Scheduler update, using the scheduler's own timestep entry for this step.
        latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

        # UI
        if progress_bar_st is not None:
            progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

    # Decode image
    self.load_models_to_device(['vae_decoder'])
    image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

    # Offload all models
    self.load_models_to_device([])
    return image
|
|
||||||
@@ -1,191 +0,0 @@
|
|||||||
from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
|
|
||||||
from ..models.model_manager import ModelManager
|
|
||||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
|
||||||
from ..prompters import SDPrompter
|
|
||||||
from ..schedulers import EnhancedDDIMScheduler
|
|
||||||
from .base import BasePipeline
|
|
||||||
from .dancer import lets_dance
|
|
||||||
from typing import List
|
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SDImagePipeline(BasePipeline):
    """Image pipeline built from a text encoder, a UNet and a VAE.

    ControlNets and an IP-Adapter can optionally take part in the denoising.
    Instances are normally created with ``from_model_manager``.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        """Create an empty pipeline; model slots stay ``None`` until ``fetch_models`` runs."""
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = EnhancedDDIMScheduler()
        self.prompter = SDPrompter()
        # models
        self.text_encoder: SDTextEncoder = None
        self.unet: SDUNet = None
        self.vae_decoder: SDVAEDecoder = None
        self.vae_encoder: SDVAEEncoder = None
        self.controlnet: MultiControlNetManager = None
        self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
        self.ipadapter: SDIpAdapter = None
        # Attribute names of the model slots above.
        self.model_names = ['text_encoder', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter']


    def denoising_model(self):
        """Return the model used for denoising (the UNet)."""
        return self.unet


    def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
        """Fill the model slots from ``model_manager``.

        Args:
            model_manager: Source of all models.
            controlnet_config_units: One config per ControlNet; each supplies
                ``processor_id``, ``model_path`` and ``scale``.
            prompt_refiner_classes: Forwarded to the prompter's
                ``load_prompt_refiners``.
        """
        # Main models
        self.text_encoder = model_manager.fetch_model("sd_text_encoder")
        self.unet = model_manager.fetch_model("sd_unet")
        self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
        self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
        self.prompter.fetch_models(self.text_encoder)
        self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)

        # ControlNets
        controlnet_units = []
        for config in controlnet_config_units:
            controlnet_unit = ControlNetUnit(
                Annotator(config.processor_id, device=self.device),
                model_manager.fetch_model("sd_controlnet", config.model_path),
                config.scale
            )
            controlnet_units.append(controlnet_unit)
        self.controlnet = MultiControlNetManager(controlnet_units)

        # IP-Adapters
        self.ipadapter = model_manager.fetch_model("sd_ipadapter")
        self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")


    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
        """Build a pipeline and populate it from ``model_manager``.

        Args:
            model_manager: Provides the models plus the default device and dtype.
            controlnet_config_units: Forwarded to ``fetch_models``.
            prompt_refiner_classes: Forwarded to ``fetch_models``.
            device: Optional override; when ``None`` the manager's device is used.

        Returns:
            The populated ``SDImagePipeline``.
        """
        pipe = SDImagePipeline(
            device=model_manager.device if device is None else device,
            torch_dtype=model_manager.torch_dtype,
        )
        # Forward the caller's refiner classes; passing a literal [] here would
        # silently discard them.
        pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes=prompt_refiner_classes)
        return pipe


    def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
        """Run the VAE encoder on ``image`` and return its output."""
        latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents


    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        """Decode ``latent`` (moved to ``self.device``) and convert it with ``vae_output_to_image``."""
        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        image = self.vae_output_to_image(image)
        return image


    def encode_prompt(self, prompt, clip_skip=1, positive=True):
        """Encode ``prompt`` and return it under the ``"encoder_hidden_states"`` key."""
        prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
        return {"encoder_hidden_states": prompt_emb}


    def prepare_extra_input(self, latents=None):
        """Return the extra model inputs for ``latents``; this pipeline has none."""
        return {}


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        local_prompts=[],
        masks=[],
        mask_scales=[],
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        input_image=None,
        ipadapter_images=None,
        ipadapter_scale=1.0,
        controlnet_image=None,
        denoising_strength=1.0,
        height=512,
        width=512,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        seed=None,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate an image from ``prompt``.

        Args:
            prompt: Main positive prompt.
            local_prompts: Extra prompts combined with ``masks`` / ``mask_scales``
                by ``control_noise_via_local_prompts``.
            masks: Passed to ``control_noise_via_local_prompts``.
            mask_scales: Passed to ``control_noise_via_local_prompts``.
            negative_prompt: Prompt for the negative guidance branch.
            cfg_scale: Weight of the (positive - negative) difference added to
                the negative prediction.
            clip_skip: Forwarded to ``encode_prompt``.
            input_image: When given, encoded and noised to form the starting latents.
            ipadapter_images: When given, encoded and fed to the IP-Adapter.
            ipadapter_scale: Scale passed to the IP-Adapter for the positive branch.
            controlnet_image: When given, processed by the ControlNet manager.
            denoising_strength: Forwarded to ``scheduler.set_timesteps``.
            height: Requested height, adjusted by ``check_resize_height_width``.
            width: Requested width, adjusted by ``check_resize_height_width``.
            num_inference_steps: Forwarded to ``scheduler.set_timesteps``.
            tiled: Tiling switch forwarded to the VAE and the UNet call.
            tile_size: Tiling option forwarded alongside ``tiled``.
            tile_stride: Tiling option forwarded alongside ``tiled``.
            seed: Seed forwarded to ``generate_noise``.
            progress_bar_cmd: Callable wrapping the timestep iterable.
            progress_bar_st: Optional object whose ``progress`` method is called
                after every step.

        Returns:
            The value produced by ``decode_image`` for the final latents.
        """
        height, width = self.check_resize_height_width(height, width)

        # Tiler parameters
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        if input_image is not None:
            self.load_models_to_device(['vae_encoder'])
            image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
            latents = self.encode_image(image, **tiler_kwargs)
            noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)

        # Encode prompts
        self.load_models_to_device(['text_encoder'])
        prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
        prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
        prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, positive=True) for prompt_local in local_prompts]

        # IP-Adapter: the negative branch uses an all-zero image encoding.
        if ipadapter_images is not None:
            self.load_models_to_device(['ipadapter_image_encoder'])
            ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
            self.load_models_to_device(['ipadapter'])
            ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
            ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
        else:
            ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}

        # Prepare ControlNets
        if controlnet_image is not None:
            self.load_models_to_device(['controlnet'])
            controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
            # Insert an extra axis at position 1 before handing it to lets_dance.
            controlnet_image = controlnet_image.unsqueeze(1)
            controlnet_kwargs = {"controlnet_frames": controlnet_image}
        else:
            controlnet_kwargs = {"controlnet_frames": None}

        # Denoise
        self.load_models_to_device(['controlnet', 'unet'])
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(self.device)

            # Classifier-free guidance
            inference_callback = lambda prompt_emb_posi: lets_dance(
                self.unet, motion_modules=None, controlnet=self.controlnet,
                sample=latents, timestep=timestep,
                **prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
                device=self.device,
            )
            noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
            noise_pred_nega = lets_dance(
                self.unet, motion_modules=None, controlnet=self.controlnet,
                sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
                device=self.device,
            )
            noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

            # DDIM
            latents = self.scheduler.step(noise_pred, timestep, latents)

            # UI
            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        self.load_models_to_device(['vae_decoder'])
        image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

        # offload all models
        self.load_models_to_device([])
        return image
|
|
||||||
@@ -1,269 +0,0 @@
|
|||||||
from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder, SDMotionModel
|
|
||||||
from ..models.model_manager import ModelManager
|
|
||||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
|
||||||
from ..prompters import SDPrompter
|
|
||||||
from ..schedulers import EnhancedDDIMScheduler
|
|
||||||
from .sd_image import SDImagePipeline
|
|
||||||
from .dancer import lets_dance
|
|
||||||
from typing import List
|
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def lets_dance_with_long_video(
    unet: SDUNet,
    motion_modules: SDMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample = None,
    timestep = None,
    encoder_hidden_states = None,
    ipadapter_kwargs_list = {},
    controlnet_frames = None,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
    cross_frame_attention = False,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    device="cuda",
    animatediff_batch_size=16,
    animatediff_stride=8,
):
    """Run ``lets_dance`` over overlapping windows of frames and blend the results.

    Windows of up to ``animatediff_batch_size`` frames start every
    ``animatediff_stride`` frames along the first axis of ``sample``. Each
    window is processed on ``device`` and brought back with ``.cpu()``. Frames
    covered by several windows get a weighted running average, where the weight
    falls off linearly from the window centre and is floored at ``1e-2``.

    Returns:
        The blended per-frame outputs stacked along a new first axis.
    """
    num_frames = sample.shape[0]

    # Per frame: the running weighted mean and the total weight accumulated so far.
    blended = [torch.zeros(sample[0].shape, dtype=sample[0].dtype) for _ in range(num_frames)]
    weights = [0] * num_frames

    for window_start in range(0, num_frames, animatediff_stride):
        window_end = min(window_start + animatediff_batch_size, num_frames)

        # Process this window.
        window_sample = sample[window_start: window_end].to(device)
        if controlnet_frames is not None:
            window_controlnet_frames = controlnet_frames[:, window_start: window_end].to(device)
        else:
            window_controlnet_frames = None
        window_output = lets_dance(
            unet, motion_modules, controlnet,
            window_sample,
            timestep,
            encoder_hidden_states,
            ipadapter_kwargs_list=ipadapter_kwargs_list,
            controlnet_frames=window_controlnet_frames,
            unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
            cross_frame_attention=cross_frame_attention,
            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, device=device
        ).cpu()

        # Fold the window's frames into the running averages.
        center = (window_start + window_end - 1) / 2
        half_span = (window_end - window_start - 1 + 1e-2) / 2
        for frame_id, frame_output in zip(range(window_start, window_end), window_output):
            bias = max(1 - abs(frame_id - center) / half_span, 1e-2)
            previous_weight = weights[frame_id]
            total_weight = previous_weight + bias
            blended[frame_id] = (
                blended[frame_id] * (previous_weight / total_weight)
                + frame_output * (bias / total_weight)
            )
            weights[frame_id] = total_weight

        # The window that reaches the last frame is the final one.
        if window_end == num_frames:
            break

    # output
    return torch.stack(blended)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SDVideoPipeline(SDImagePipeline):
    """Video variant of ``SDImagePipeline`` that denoises a stack of per-frame latents.

    Frames are processed in overlapping windows by
    ``lets_dance_with_long_video``, optionally with motion modules, ControlNets,
    an IP-Adapter and a frame smoother.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
        """Create an empty pipeline.

        Args:
            device: Device identifier forwarded to the parent pipeline.
            torch_dtype: Tensor dtype forwarded to the parent pipeline.
            use_original_animatediff: Selects the ``"linear"`` beta schedule when
                true and ``"scaled_linear"`` otherwise.
        """
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
        self.prompter = SDPrompter()
        # models
        self.text_encoder: SDTextEncoder = None
        self.unet: SDUNet = None
        self.vae_decoder: SDVAEDecoder = None
        self.vae_encoder: SDVAEEncoder = None
        self.controlnet: MultiControlNetManager = None
        self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
        self.ipadapter: SDIpAdapter = None
        self.motion_modules: SDMotionModel = None


    def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
        """Fill the model slots from ``model_manager``.

        Args:
            model_manager: Source of all models.
            controlnet_config_units: One config per ControlNet; each supplies
                ``processor_id``, ``model_path`` and ``scale``.
            prompt_refiner_classes: Forwarded to the prompter's
                ``load_prompt_refiners``.
        """
        # Main models
        self.text_encoder = model_manager.fetch_model("sd_text_encoder")
        self.unet = model_manager.fetch_model("sd_unet")
        self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
        self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
        self.prompter.fetch_models(self.text_encoder)
        self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)

        # ControlNets
        controlnet_units = []
        for config in controlnet_config_units:
            controlnet_unit = ControlNetUnit(
                Annotator(config.processor_id, device=self.device),
                model_manager.fetch_model("sd_controlnet", config.model_path),
                config.scale
            )
            controlnet_units.append(controlnet_unit)
        self.controlnet = MultiControlNetManager(controlnet_units)

        # IP-Adapters
        self.ipadapter = model_manager.fetch_model("sd_ipadapter")
        self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")

        # Motion Modules: without them the scheduler falls back to "scaled_linear".
        self.motion_modules = model_manager.fetch_model("sd_motion_modules")
        if self.motion_modules is None:
            self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")


    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
        """Build a pipeline and populate it from ``model_manager``.

        Args:
            model_manager: Provides the models plus the default device and dtype.
            controlnet_config_units: Forwarded to ``fetch_models``.
            prompt_refiner_classes: Forwarded to ``fetch_models``.
            device: Optional override; when ``None`` the manager's device is
                used. Mirrors the parameter offered by the image pipelines.

        Returns:
            The populated ``SDVideoPipeline``.
        """
        pipe = SDVideoPipeline(
            device=model_manager.device if device is None else device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
        return pipe


    def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
        """Decode every frame of ``latents`` separately and return the results as a list."""
        images = [
            self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            for frame_id in range(latents.shape[0])
        ]
        return images


    def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
        """Encode each image separately and concatenate the latents on the CPU along axis 0."""
        latents = []
        for image in processed_images:
            image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
            latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            latents.append(latent.cpu())
        latents = torch.concat(latents, dim=0)
        return latents


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        num_frames=None,
        input_frames=None,
        ipadapter_images=None,
        ipadapter_scale=1.0,
        controlnet_frames=None,
        denoising_strength=1.0,
        height=512,
        width=512,
        num_inference_steps=20,
        animatediff_batch_size = 16,
        animatediff_stride = 8,
        unet_batch_size = 1,
        controlnet_batch_size = 1,
        cross_frame_attention = False,
        smoother=None,
        smoother_progress_ids=None,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        seed=None,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate (or re-render) a sequence of frames from ``prompt``.

        Args:
            prompt: Positive prompt.
            negative_prompt: Prompt for the negative guidance branch.
            cfg_scale: Weight of the (positive - negative) difference added to
                the negative prediction.
            clip_skip: Forwarded to ``encode_prompt``.
            num_frames: Number of frames. When ``None`` it is taken from
                ``len(input_frames)``.
            input_frames: Optional source frames; encoded and noised unless
                ``denoising_strength`` is exactly 1.0.
            ipadapter_images: When given, encoded and fed to the IP-Adapter.
            ipadapter_scale: Scale passed to the IP-Adapter for the positive branch.
            controlnet_frames: Frames for the ControlNets. Either one sequence of
                frames, or a list with one frame list per processor.
            denoising_strength: Forwarded to ``scheduler.set_timesteps``.
            height: Requested height, adjusted by ``check_resize_height_width``.
            width: Requested width, adjusted by ``check_resize_height_width``.
            num_inference_steps: Forwarded to ``scheduler.set_timesteps``.
            animatediff_batch_size: Window length for ``lets_dance_with_long_video``.
            animatediff_stride: Window stride for ``lets_dance_with_long_video``.
            unet_batch_size: Forwarded to ``lets_dance_with_long_video``.
            controlnet_batch_size: Forwarded to ``lets_dance_with_long_video``.
            cross_frame_attention: Forwarded to ``lets_dance_with_long_video``.
            smoother: Optional callable applied to rendered frames.
            smoother_progress_ids: Step indices at which the smoother runs inside
                the loop. Containing ``num_inference_steps`` or ``-1`` also
                applies it to the final frames. ``None`` means no steps.
            tiled: Tiling switch forwarded to the VAE and the UNet call.
            tile_size: Tiling option forwarded alongside ``tiled``.
            tile_stride: Tiling option forwarded alongside ``tiled``.
            seed: Seed forwarded to ``generate_noise``.
            progress_bar_cmd: Callable wrapping iterables to show progress.
            progress_bar_st: Optional object whose ``progress`` method is called
                after every step.

        Returns:
            The list of decoded frames (after the smoother, if it applies).

        Raises:
            ValueError: If neither ``num_frames`` nor ``input_frames`` is given.
        """
        height, width = self.check_resize_height_width(height, width)

        # None replaces a shared mutable default list.
        if smoother_progress_ids is None:
            smoother_progress_ids = []

        # The frame count is needed for the noise shape; derive it when omitted.
        if num_frames is None:
            if input_frames is None:
                raise ValueError("Either `num_frames` or `input_frames` must be provided.")
            num_frames = len(input_frames)

        # Tiler parameters, batch size ...
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
        other_kwargs = {
            "animatediff_batch_size": animatediff_batch_size, "animatediff_stride": animatediff_stride,
            "unet_batch_size": unet_batch_size, "controlnet_batch_size": controlnet_batch_size,
            "cross_frame_attention": cross_frame_attention,
        }

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors (kept on the CPU; windows are moved to the
        # device inside lets_dance_with_long_video).
        if self.motion_modules is None:
            # Without motion modules every frame starts from the same noise.
            noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
        else:
            noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
        if input_frames is None or denoising_strength == 1.0:
            latents = noise
        else:
            latents = self.encode_video(input_frames, **tiler_kwargs)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])

        # Encode prompts
        prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
        prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)

        # IP-Adapter: the negative branch uses an all-zero image encoding.
        if ipadapter_images is not None:
            ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
            ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
            ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
        else:
            ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}

        # Prepare ControlNets
        if controlnet_frames is not None:
            if isinstance(controlnet_frames[0], list):
                # One list of frames per processor: stack frames on axis 1,
                # then join the processors on axis 0.
                controlnet_frames_ = []
                for processor_id in range(len(controlnet_frames)):
                    controlnet_frames_.append(
                        torch.stack([
                            self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
                            for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
                        ], dim=1)
                    )
                controlnet_frames = torch.concat(controlnet_frames_, dim=0)
            else:
                controlnet_frames = torch.stack([
                    self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
                    for controlnet_frame in progress_bar_cmd(controlnet_frames)
                ], dim=1)
            controlnet_kwargs = {"controlnet_frames": controlnet_frames}
        else:
            controlnet_kwargs = {"controlnet_frames": None}

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(self.device)

            # Classifier-free guidance
            noise_pred_posi = lets_dance_with_long_video(
                self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
                sample=latents, timestep=timestep,
                **prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **other_kwargs, **tiler_kwargs,
                device=self.device,
            )
            noise_pred_nega = lets_dance_with_long_video(
                self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
                sample=latents, timestep=timestep,
                **prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **other_kwargs, **tiler_kwargs,
                device=self.device,
            )
            noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

            # DDIM and smoother
            if smoother is not None and progress_id in smoother_progress_ids:
                # Render the predicted final frames, smooth them, and replace the
                # noise prediction with the one that leads to the smoothed result.
                # NOTE(review): these decode/encode calls use the default tiling
                # options rather than tiler_kwargs; confirm whether that is intended.
                rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
                rendered_frames = self.decode_video(rendered_frames)
                rendered_frames = smoother(rendered_frames, original_frames=input_frames)
                target_latents = self.encode_video(rendered_frames)
                noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
            latents = self.scheduler.step(noise_pred, timestep, latents)

            # UI
            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        output_frames = self.decode_video(latents, **tiler_kwargs)

        # Post-process
        if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
            output_frames = smoother(output_frames, original_frames=input_frames)

        return output_frames
|
|
||||||
@@ -1,226 +0,0 @@
|
|||||||
from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
|
||||||
from ..models.kolors_text_encoder import ChatGLMModel
|
|
||||||
from ..models.model_manager import ModelManager
|
|
||||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
|
||||||
from ..prompters import SDXLPrompter, KolorsPrompter
|
|
||||||
from ..schedulers import EnhancedDDIMScheduler
|
|
||||||
from .base import BasePipeline
|
|
||||||
from .dancer import lets_dance_xl
|
|
||||||
from typing import List
|
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
from einops import repeat
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SDXLImagePipeline(BasePipeline):
|
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
|
||||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
|
||||||
self.scheduler = EnhancedDDIMScheduler()
|
|
||||||
self.prompter = SDXLPrompter()
|
|
||||||
# models
|
|
||||||
self.text_encoder: SDXLTextEncoder = None
|
|
||||||
self.text_encoder_2: SDXLTextEncoder2 = None
|
|
||||||
self.text_encoder_kolors: ChatGLMModel = None
|
|
||||||
self.unet: SDXLUNet = None
|
|
||||||
self.vae_decoder: SDXLVAEDecoder = None
|
|
||||||
self.vae_encoder: SDXLVAEEncoder = None
|
|
||||||
self.controlnet: MultiControlNetManager = None
|
|
||||||
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
|
|
||||||
self.ipadapter: SDXLIpAdapter = None
|
|
||||||
self.model_names = ['text_encoder', 'text_encoder_2', 'text_encoder_kolors', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter']
|
|
||||||
|
|
||||||
|
|
||||||
def denoising_model(self):
|
|
||||||
return self.unet
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
|
|
||||||
# Main models
|
|
||||||
self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
|
|
||||||
self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
|
|
||||||
self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
|
|
||||||
self.unet = model_manager.fetch_model("sdxl_unet")
|
|
||||||
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
|
|
||||||
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
|
|
||||||
|
|
||||||
# ControlNets
|
|
||||||
controlnet_units = []
|
|
||||||
for config in controlnet_config_units:
|
|
||||||
controlnet_unit = ControlNetUnit(
|
|
||||||
Annotator(config.processor_id, device=self.device),
|
|
||||||
model_manager.fetch_model("sdxl_controlnet", config.model_path),
|
|
||||||
config.scale
|
|
||||||
)
|
|
||||||
controlnet_units.append(controlnet_unit)
|
|
||||||
self.controlnet = MultiControlNetManager(controlnet_units)
|
|
||||||
|
|
||||||
# IP-Adapters
|
|
||||||
self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
|
|
||||||
self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")
|
|
||||||
|
|
||||||
# Kolors
|
|
||||||
if self.text_encoder_kolors is not None:
|
|
||||||
print("Switch to Kolors. The prompter and scheduler will be replaced.")
|
|
||||||
self.prompter = KolorsPrompter()
|
|
||||||
self.prompter.fetch_models(self.text_encoder_kolors)
|
|
||||||
self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
|
|
||||||
else:
|
|
||||||
self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
|
|
||||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
|
|
||||||
pipe = SDXLImagePipeline(
|
|
||||||
device=model_manager.device if device is None else device,
|
|
||||||
torch_dtype=model_manager.torch_dtype,
|
|
||||||
)
|
|
||||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
|
|
||||||
return pipe
|
|
||||||
|
|
||||||
|
|
||||||
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
|
|
||||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
return latents
|
|
||||||
|
|
||||||
|
|
||||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
|
||||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
image = self.vae_output_to_image(image)
|
|
||||||
return image
|
|
||||||
|
|
||||||
|
|
||||||
def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=2, positive=True):
|
|
||||||
add_prompt_emb, prompt_emb = self.prompter.encode_prompt(
|
|
||||||
prompt,
|
|
||||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
|
||||||
device=self.device,
|
|
||||||
positive=positive,
|
|
||||||
)
|
|
||||||
return {"encoder_hidden_states": prompt_emb, "add_text_embeds": add_prompt_emb}
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_extra_input(self, latents=None):
|
|
||||||
height, width = latents.shape[2] * 8, latents.shape[3] * 8
|
|
||||||
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device).repeat(latents.shape[0])
|
|
||||||
return {"add_time_id": add_time_id}
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
def __call__(
    self,
    prompt,
    local_prompts=[],
    masks=[],
    mask_scales=[],
    negative_prompt="",
    cfg_scale=7.5,
    clip_skip=1,
    clip_skip_2=2,
    input_image=None,
    ipadapter_images=None,
    ipadapter_scale=1.0,
    ipadapter_use_instant_style=False,
    controlnet_image=None,
    denoising_strength=1.0,
    height=1024,
    width=1024,
    num_inference_steps=20,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    seed=None,
    progress_bar_cmd=tqdm,
    progress_bar_st=None,
):
    """Generate one image from ``prompt`` by iterative denoising.

    Stages, in order: set up the scheduler, build the starting latents (pure
    noise, or the noised encoding of ``input_image``), encode the positive,
    negative and local prompts, prepare the optional IP-Adapter and ControlNet
    inputs, run the denoising loop, and decode the final latents.

    The negative branch of classifier-free guidance is skipped when
    ``cfg_scale == 1.0``. ``progress_bar_cmd`` wraps the timestep iterable;
    ``progress_bar_st``, when given, has ``.progress(fraction)`` called after
    each step.

    Returns:
        Whatever ``self.decode_image`` returns for the final latents.
    """
    height, width = self.check_resize_height_width(height, width)
    latent_shape = (1, 4, height//8, width//8)

    # Tiling options, forwarded to the VAE calls and to lets_dance_xl.
    tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

    # Scheduler
    self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

    # Starting latents
    if input_image is None:
        latents = self.generate_noise(latent_shape, seed=seed, device=self.device, dtype=self.torch_dtype)
    else:
        self.load_models_to_device(['vae_encoder'])
        image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
        latents = self.encode_image(image, **tiler_kwargs)
        noise = self.generate_noise(latent_shape, seed=seed, device=self.device, dtype=self.torch_dtype)
        # Noise the encoded image up to the first scheduled timestep.
        latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])

    # Prompt embeddings
    self.load_models_to_device(['text_encoder', 'text_encoder_2', 'text_encoder_kolors'])
    prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
    prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
    prompt_emb_locals = [
        self.encode_prompt(local_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
        for local_prompt in local_prompts
    ]

    # IP-Adapter
    if ipadapter_images is None:
        ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": {}}
        ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}
    else:
        if ipadapter_use_instant_style:
            self.ipadapter.set_less_adapter()
        else:
            self.ipadapter.set_full_adapter()
        self.load_models_to_device(['ipadapter_image_encoder'])
        ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
        self.load_models_to_device(['ipadapter'])
        ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
        # The negative branch gets an all-zero image encoding.
        ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}

    # ControlNet
    if controlnet_image is None:
        controlnet_kwargs = {"controlnet_frames": None}
    else:
        self.load_models_to_device(['controlnet'])
        controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
        controlnet_image = controlnet_image.unsqueeze(1)
        controlnet_kwargs = {"controlnet_frames": controlnet_image}

    # Extra U-Net inputs derived from the latent size
    extra_input = self.prepare_extra_input(latents)

    def run_unet(sample, step, prompt_emb, ipadapter_kwargs):
        # One lets_dance_xl call for a given prompt embedding and IP-Adapter kwargs.
        return lets_dance_xl(
            self.unet, motion_modules=None, controlnet=self.controlnet,
            sample=sample, timestep=step, **extra_input,
            **prompt_emb, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs,
            device=self.device,
        )

    # Denoising loop
    self.load_models_to_device(['controlnet', 'unet'])
    for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
        timestep = timestep.unsqueeze(0).to(self.device)

        # Positive branch; the callback lets control_noise_via_local_prompts
        # evaluate the U-Net for whichever prompt embedding it passes in.
        inference_callback = lambda prompt_emb: run_unet(latents, timestep, prompt_emb, ipadapter_kwargs_list_posi)
        noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)

        # Classifier-free guidance
        if cfg_scale != 1.0:
            noise_pred_nega = run_unet(latents, timestep, prompt_emb_nega, ipadapter_kwargs_list_nega)
            noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
        else:
            noise_pred = noise_pred_posi

        # Scheduler step
        latents = self.scheduler.step(noise_pred, timestep, latents)

        # UI progress
        if progress_bar_st is not None:
            progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

    # Decode
    self.load_models_to_device(['vae_decoder'])
    image = self.decode_image(latents, **tiler_kwargs)

    # Offload all models
    self.load_models_to_device([])
    return image
|
|
||||||
@@ -1,226 +0,0 @@
|
|||||||
from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder, SDXLMotionModel
|
|
||||||
from ..models.kolors_text_encoder import ChatGLMModel
|
|
||||||
from ..models.model_manager import ModelManager
|
|
||||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
|
||||||
from ..prompters import SDXLPrompter, KolorsPrompter
|
|
||||||
from ..schedulers import EnhancedDDIMScheduler
|
|
||||||
from .sdxl_image import SDXLImagePipeline
|
|
||||||
from .dancer import lets_dance_xl
|
|
||||||
from typing import List
|
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SDXLVideoPipeline(SDXLImagePipeline):
    """SDXL pipeline that denoises a batch of frames, optionally with motion modules.

    Extends ``SDXLImagePipeline`` with per-frame VAE encode/decode helpers and a
    ``__call__`` that runs classifier-free guidance over all frames at once.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
        """Create an empty pipeline; models are attached later by ``fetch_models``.

        Args:
            device: Device passed to the parent pipeline.
            torch_dtype: Dtype passed to the parent pipeline.
            use_original_animatediff: Selects the ``"linear"`` beta schedule when
                true, ``"scaled_linear"`` otherwise.
        """
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
        self.prompter = SDXLPrompter()
        # models
        self.text_encoder: SDXLTextEncoder = None
        self.text_encoder_2: SDXLTextEncoder2 = None
        self.text_encoder_kolors: ChatGLMModel = None
        self.unet: SDXLUNet = None
        self.vae_decoder: SDXLVAEDecoder = None
        self.vae_encoder: SDXLVAEEncoder = None
        # self.controlnet: MultiControlNetManager = None (TODO)
        self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
        self.ipadapter: SDXLIpAdapter = None
        self.motion_modules: SDXLMotionModel = None


    def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
        """Attach models from ``model_manager`` and pick the prompter and scheduler.

        ``controlnet_config_units`` is accepted but not used here (ControlNet
        support is marked TODO).
        """
        # Main models
        self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
        self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
        self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
        self.unet = model_manager.fetch_model("sdxl_unet")
        self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
        self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
        self.prompter.fetch_models(self.text_encoder)
        self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)

        # ControlNets (TODO)

        # IP-Adapters
        self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
        self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")

        # Motion Modules
        self.motion_modules = model_manager.fetch_model("sdxl_motion_modules")
        if self.motion_modules is None:
            # Without motion modules, replace the scheduler chosen in __init__.
            self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")

        # Kolors
        if self.text_encoder_kolors is not None:
            print("Switch to Kolors. The prompter will be replaced.")
            self.prompter = KolorsPrompter()
            self.prompter.fetch_models(self.text_encoder_kolors)
            # The schedulers of AnimateDiff and Kolors are incompatible. We align it with AnimateDiff.
            if self.motion_modules is None:
                self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
        else:
            self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)


    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
        """Create a pipeline on the manager's device/dtype and load its models."""
        pipe = SDXLVideoPipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
        return pipe


    def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
        """Decode ``latents`` one frame at a time and return the list of decoded frames."""
        images = [
            self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            for frame_id in range(latents.shape[0])
        ]
        return images


    def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
        """Encode each image separately and return the latents concatenated along dim 0.

        Each per-image latent is moved to the CPU before concatenation, so the
        returned tensor lives on the CPU.
        """
        latents = []
        for image in processed_images:
            image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
            latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            latents.append(latent.cpu())
        latents = torch.concat(latents, dim=0)
        return latents


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        num_frames=None,
        input_frames=None,
        ipadapter_images=None,
        ipadapter_scale=1.0,
        ipadapter_use_instant_style=False,
        controlnet_frames=None,
        denoising_strength=1.0,
        height=512,
        width=512,
        num_inference_steps=20,
        animatediff_batch_size = 16,
        animatediff_stride = 8,
        unet_batch_size = 1,
        controlnet_batch_size = 1,
        cross_frame_attention = False,
        smoother=None,
        smoother_progress_ids=[],
        tiled=False,
        tile_size=64,
        tile_stride=32,
        seed=None,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate ``num_frames`` frames from ``prompt``.

        Noise is drawn on the CPU: one frame's worth repeated for every frame
        when no motion modules are loaded, otherwise independent noise per
        frame. When ``input_frames`` is given and ``denoising_strength`` is not
        1.0, the frames are encoded and noised instead of starting from noise.

        ``animatediff_batch_size``, ``animatediff_stride``, ``unet_batch_size``,
        ``controlnet_batch_size`` and ``cross_frame_attention`` are accepted but
        not read by this method.

        When ``smoother`` is given, it is applied inside the loop at the steps
        listed in ``smoother_progress_ids``, and to the final frames when that
        list contains ``num_inference_steps`` or ``-1``.

        Returns:
            The list of decoded frames from ``decode_video`` (after the optional
            final smoothing).
        """
        height, width = self.check_resize_height_width(height, width)

        # Tiler parameters, batch size ...
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        if self.motion_modules is None:
            noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
        else:
            noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
        if input_frames is None or denoising_strength == 1.0:
            latents = noise
        else:
            latents = self.encode_video(input_frames, **tiler_kwargs)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        latents = latents.to(self.device) # will be deleted for supporting long videos

        # Encode prompts
        prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
        prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)

        # IP-Adapter
        if ipadapter_images is not None:
            if ipadapter_use_instant_style:
                self.ipadapter.set_less_adapter()
            else:
                self.ipadapter.set_full_adapter()
            ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
            ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
            # The negative branch gets an all-zero image encoding.
            ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
        else:
            ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}

        # Prepare ControlNets
        if controlnet_frames is not None:
            if isinstance(controlnet_frames[0], list):
                # One list of frames per processor: stack each along dim 1,
                # then concatenate the processors along dim 0.
                controlnet_frames_ = []
                for processor_id in range(len(controlnet_frames)):
                    controlnet_frames_.append(
                        torch.stack([
                            self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
                            for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
                        ], dim=1)
                    )
                controlnet_frames = torch.concat(controlnet_frames_, dim=0)
            else:
                controlnet_frames = torch.stack([
                    self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
                    for controlnet_frame in progress_bar_cmd(controlnet_frames)
                ], dim=1)
            controlnet_kwargs = {"controlnet_frames": controlnet_frames}
        else:
            controlnet_kwargs = {"controlnet_frames": None}

        # Prepare extra input
        extra_input = self.prepare_extra_input(latents)

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(self.device)

            # Classifier-free guidance
            noise_pred_posi = lets_dance_xl(
                self.unet, motion_modules=self.motion_modules, controlnet=None,
                sample=latents, timestep=timestep,
                **prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **extra_input, **tiler_kwargs,
                device=self.device,
            )
            noise_pred_nega = lets_dance_xl(
                self.unet, motion_modules=self.motion_modules, controlnet=None,
                sample=latents, timestep=timestep,
                **prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **extra_input, **tiler_kwargs,
                device=self.device,
            )
            noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

            # DDIM and smoother
            if smoother is not None and progress_id in smoother_progress_ids:
                rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
                rendered_frames = self.decode_video(rendered_frames)
                rendered_frames = smoother(rendered_frames, original_frames=input_frames)
                # encode_video returns CPU tensors, while `latents` was moved to
                # self.device above; bring the target onto the same device
                # before the scheduler combines the two.
                target_latents = self.encode_video(rendered_frames).to(latents.device)
                noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
            latents = self.scheduler.step(noise_pred, timestep, latents)

            # UI
            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        output_frames = self.decode_video(latents, **tiler_kwargs)

        # Post-process
        if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
            output_frames = smoother(output_frames, original_frames=input_frames)

        return output_frames
|
|
||||||
167
diffsynth/pipelines/stable_diffusion.py
Normal file
167
diffsynth/pipelines/stable_diffusion.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
|
||||||
|
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||||
|
from ..prompts import SDPrompter
|
||||||
|
from ..schedulers import EnhancedDDIMScheduler
|
||||||
|
from .dancer import lets_dance
|
||||||
|
from typing import List
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class SDImagePipeline(torch.nn.Module):
    """Stable Diffusion image pipeline.

    Holds a scheduler, a prompter and handles to the text encoder, U-Net, VAE,
    ControlNet manager and IP-Adapter, and runs text-to-image or image-to-image
    sampling with classifier-free guidance.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        """Create an empty pipeline; the ``fetch_*`` methods attach the models."""
        super().__init__()
        self.scheduler = EnhancedDDIMScheduler()
        self.prompter = SDPrompter()
        self.device = device
        self.torch_dtype = torch_dtype
        # Model handles, all unset until fetched.
        self.text_encoder: SDTextEncoder = None
        self.unet: SDUNet = None
        self.vae_decoder: SDVAEDecoder = None
        self.vae_encoder: SDVAEEncoder = None
        self.controlnet: MultiControlNetManager = None
        self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
        self.ipadapter: SDIpAdapter = None


    def fetch_main_models(self, model_manager: ModelManager):
        """Copy the text encoder, U-Net and both VAE halves from ``model_manager``."""
        self.text_encoder = model_manager.text_encoder
        self.unet = model_manager.unet
        self.vae_decoder = model_manager.vae_decoder
        self.vae_encoder = model_manager.vae_encoder


    def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
        """Build one ControlNetUnit per config entry and wrap them in a MultiControlNetManager."""
        units = [
            ControlNetUnit(
                Annotator(unit_config.processor_id),
                model_manager.get_model_with_model_path(unit_config.model_path),
                unit_config.scale
            )
            for unit_config in controlnet_config_units
        ]
        self.controlnet = MultiControlNetManager(units)


    def fetch_ipadapter(self, model_manager: ModelManager):
        """Attach the IP-Adapter and its image encoder when the manager has them loaded."""
        loaded_models = model_manager.model
        if "ipadapter" in loaded_models:
            self.ipadapter = model_manager.ipadapter
        if "ipadapter_image_encoder" in loaded_models:
            self.ipadapter_image_encoder = model_manager.ipadapter_image_encoder


    def fetch_prompter(self, model_manager: ModelManager):
        """Let the prompter load what it needs from ``model_manager``."""
        self.prompter.load_from_model_manager(model_manager)


    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
        """Create a pipeline on the manager's device/dtype and attach all available models."""
        pipeline = SDImagePipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipeline.fetch_main_models(model_manager)
        pipeline.fetch_prompter(model_manager)
        pipeline.fetch_controlnet_models(model_manager, controlnet_config_units)
        pipeline.fetch_ipadapter(model_manager)
        return pipeline


    def preprocess_image(self, image):
        """Convert ``image`` to a float tensor scaled from [0, 255] to [-1, 1].

        The array's last axis is moved to the front and a leading axis of size 1
        is added.
        """
        scaled = np.array(image, dtype=np.float32) * (2 / 255) - 1
        tensor = torch.Tensor(scaled)
        return tensor.permute(2, 0, 1).unsqueeze(0)


    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        """Decode ``latent`` with the VAE decoder and return the first result as a PIL image."""
        decoded = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        pixels = decoded.cpu().permute(1, 2, 0).numpy()
        # Map [-1, 1] back to [0, 255], clipping anything outside the range.
        pixels = ((pixels / 2 + 0.5).clip(0, 1) * 255).astype("uint8")
        return Image.fromarray(pixels)


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        input_image=None,
        ipadapter_images=None,
        ipadapter_scale=1.0,
        controlnet_image=None,
        denoising_strength=1.0,
        height=512,
        width=512,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate one image from ``prompt``.

        Starts from random latents, or from the noised VAE encoding of
        ``input_image`` when one is given, then runs the denoising loop with
        classifier-free guidance and decodes the result.

        Returns:
            The PIL image produced by ``decode_image``.
        """
        latent_shape = (1, 4, height//8, width//8)

        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Starting latents
        if input_image is None:
            latents = torch.randn(latent_shape, device=self.device, dtype=self.torch_dtype)
        else:
            image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
            latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            noise = torch.randn(latent_shape, device=self.device, dtype=self.torch_dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])

        # Prompt embeddings
        prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True)
        prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False)

        # IP-Adapter
        if ipadapter_images is None:
            ipadapter_kwargs_list_posi = {}
            ipadapter_kwargs_list_nega = {}
        else:
            ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
            ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
            # The negative branch gets an all-zero image encoding.
            ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))

        # ControlNet
        if controlnet_image is not None:
            controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
            controlnet_image = controlnet_image.unsqueeze(1)

        def predict_noise(prompt_emb, ipadapter_kwargs_list):
            # One U-Net evaluation at the current `latents` / `timestep`.
            return lets_dance(
                self.unet, motion_modules=None, controlnet=self.controlnet,
                sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb, controlnet_frames=controlnet_image,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
                ipadapter_kwargs_list=ipadapter_kwargs_list,
                device=self.device, vram_limit_level=0
            )

        # Denoising loop
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = torch.IntTensor((timestep,))[0].to(self.device)

            # Classifier-free guidance
            noise_pred_posi = predict_noise(prompt_emb_posi, ipadapter_kwargs_list_posi)
            noise_pred_nega = predict_noise(prompt_emb_nega, ipadapter_kwargs_list_nega)
            noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

            # Scheduler step
            latents = self.scheduler.step(noise_pred, timestep, latents)

            # UI progress
            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode
        return self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
371
diffsynth/pipelines/stable_diffusion_video.py
Normal file
371
diffsynth/pipelines/stable_diffusion_video.py
Normal file
@@ -0,0 +1,371 @@
|
|||||||
|
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDMotionModel
|
||||||
|
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||||
|
from ..prompts import SDPrompter
|
||||||
|
from ..schedulers import EnhancedDDIMScheduler
|
||||||
|
from ..data import VideoData, save_frames, save_video
|
||||||
|
from .dancer import lets_dance
|
||||||
|
from ..processors.sequencial_processor import SequencialProcessor
|
||||||
|
from typing import List
|
||||||
|
import torch, os, json
|
||||||
|
from tqdm import tqdm
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
def lets_dance_with_long_video(
    unet: SDUNet,
    motion_modules: SDMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample = None,
    timestep = None,
    encoder_hidden_states = None,
    controlnet_frames = None,
    animatediff_batch_size = 16,
    animatediff_stride = 8,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
    cross_frame_attention = False,
    device = "cuda",
    vram_limit_level = 0,
):
    """Run ``lets_dance`` over overlapping windows of frames and blend the results.

    Windows of up to ``animatediff_batch_size`` frames start every
    ``animatediff_stride`` frames. Each window is moved to ``device``, processed,
    and moved back to the CPU. Every frame keeps a running weighted average of
    the window outputs covering it; a frame's weight is highest at the window
    centre and falls off linearly towards the edges, with a floor of 1e-2.

    Returns:
        The blended per-frame outputs stacked along dim 0 (a CPU tensor).
    """
    num_frames = sample.shape[0]
    # Per frame: (running weighted average, accumulated weight).
    blended = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for _ in range(num_frames)]

    for window_start in range(0, num_frames, animatediff_stride):
        window_end = min(window_start + animatediff_batch_size, num_frames)

        # Process this window.
        sample_window = sample[window_start: window_end].to(device)
        text_window = encoder_hidden_states[window_start: window_end].to(device)
        if controlnet_frames is not None:
            control_window = controlnet_frames[:, window_start: window_end].to(device)
        else:
            control_window = None
        window_output = lets_dance(
            unet, motion_modules, controlnet,
            sample_window,
            timestep,
            text_window,
            controlnet_frames=control_window,
            unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
            cross_frame_attention=cross_frame_attention,
            device=device, vram_limit_level=vram_limit_level
        ).cpu()

        # Fold the window's frames into the running averages.
        window_centre = (window_start + window_end - 1) / 2
        half_width = (window_end - window_start - 1 + 1e-2) / 2
        for frame_index, frame_output in zip(range(window_start, window_end), window_output):
            weight = max(1 - abs(frame_index - window_centre) / half_width, 1e-2)
            average, total_weight = blended[frame_index]
            average = average * (total_weight / (total_weight + weight)) + frame_output * (weight / (total_weight + weight))
            blended[frame_index] = (average, total_weight + weight)

        # The last frame has been covered; later windows would only repeat the tail.
        if window_end == num_frames:
            break

    return torch.stack([average for average, _ in blended])
|
||||||
|
|
||||||
|
|
||||||
|
class SDVideoPipeline(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
    """Create an empty video pipeline; the ``fetch_*`` methods attach the models.

    ``use_animatediff`` selects the ``"linear"`` beta schedule when true and
    ``"scaled_linear"`` otherwise.
    """
    super().__init__()
    if use_animatediff:
        beta_schedule = "linear"
    else:
        beta_schedule = "scaled_linear"
    self.scheduler = EnhancedDDIMScheduler(beta_schedule=beta_schedule)
    self.prompter = SDPrompter()
    self.device = device
    self.torch_dtype = torch_dtype
    # Model handles, all unset until fetched.
    self.text_encoder: SDTextEncoder = None
    self.unet: SDUNet = None
    self.vae_decoder: SDVAEDecoder = None
    self.vae_encoder: SDVAEEncoder = None
    self.controlnet: MultiControlNetManager = None
    self.motion_modules: SDMotionModel = None
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_main_models(self, model_manager: ModelManager):
    """Copy the text encoder, U-Net, VAE decoder and VAE encoder from ``model_manager``."""
    for attribute in ("text_encoder", "unet", "vae_decoder", "vae_encoder"):
        setattr(self, attribute, getattr(model_manager, attribute))
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
    """Build one ControlNetUnit per config entry and store them in a MultiControlNetManager.

    Each unit combines an ``Annotator`` for the entry's ``processor_id``, the
    model the manager returns for its ``model_path``, and its ``scale``.
    """
    units = [
        ControlNetUnit(
            Annotator(unit_config.processor_id),
            model_manager.get_model_with_model_path(unit_config.model_path),
            unit_config.scale
        )
        for unit_config in controlnet_config_units
    ]
    self.controlnet = MultiControlNetManager(units)
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_motion_modules(self, model_manager: ModelManager):
    """Attach the manager's motion modules when ``"motion_modules"`` is among its loaded models."""
    if "motion_modules" not in model_manager.model:
        return
    self.motion_modules = model_manager.motion_modules
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_prompter(self, model_manager: ModelManager):
    """Let ``self.prompter`` load what it needs from ``model_manager``."""
    prompter = self.prompter
    prompter.load_from_model_manager(model_manager)
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
    """Create an SDVideoPipeline from ``model_manager`` and attach its models.

    The AnimateDiff scheduler setting is enabled exactly when the manager has
    ``"motion_modules"`` among its loaded models.
    """
    init_kwargs = {
        "device": model_manager.device,
        "torch_dtype": model_manager.torch_dtype,
        "use_animatediff": "motion_modules" in model_manager.model,
    }
    pipeline = SDVideoPipeline(**init_kwargs)
    pipeline.fetch_main_models(model_manager)
    pipeline.fetch_motion_modules(model_manager)
    pipeline.fetch_prompter(model_manager)
    pipeline.fetch_controlnet_models(model_manager, controlnet_config_units)
    return pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_image(self, image):
    """Convert ``image`` to a float tensor scaled from [0, 255] to [-1, 1].

    The array's last axis is moved to the front and a leading axis of size 1 is
    added.
    """
    scaled = np.array(image, dtype=np.float32) * (2 / 255) - 1
    tensor = torch.Tensor(scaled)
    return tensor.permute(2, 0, 1).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
    """Decode ``latent`` with the VAE decoder and return the first result as a PIL image."""
    decoded = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
    pixels = decoded.cpu().permute(1, 2, 0).numpy()
    # Map [-1, 1] back to [0, 255], clipping anything outside the range.
    pixels = ((pixels / 2 + 0.5).clip(0, 1) * 255).astype("uint8")
    return Image.fromarray(pixels)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
    """Decode ``latents`` one entry at a time along dim 0 and return the results as a list.

    Each entry is passed to ``self.decode_image`` as a slice that keeps dim 0
    (length 1).
    """
    decoded_frames = []
    for frame_index in range(latents.shape[0]):
        single_latent = latents[frame_index: frame_index+1]
        decoded_frames.append(
            self.decode_image(single_latent, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        )
    return decoded_frames
|
||||||
|
|
||||||
|
|
||||||
|
def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
    """Encode each image with the VAE encoder and concatenate the latents along dim 0.

    Every image is preprocessed, moved to ``self.device`` with
    ``self.torch_dtype``, and encoded; each latent is moved to the CPU before
    concatenation.
    """
    cpu_latents = []
    for frame in processed_images:
        frame_tensor = self.preprocess_image(frame).to(device=self.device, dtype=self.torch_dtype)
        encoded = self.vae_encoder(frame_tensor, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        cpu_latents.append(encoded.cpu())
    return torch.concat(cpu_latents, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
    """Optionally re-normalize latents per entry, then scale them.

    With ``post_normalize`` set, each entry along axis 0 is standardized over
    axes 1-3 and then given the mean and standard deviation of the whole
    tensor. The result is multiplied by ``contrast_enhance_scale``.
    """
    if post_normalize:
        per_entry_axes = [1, 2, 3]
        global_mean = latents.mean()
        global_std = latents.std()
        entry_mean = latents.mean(dim=per_entry_axes, keepdim=True)
        entry_std = latents.std(dim=per_entry_axes, keepdim=True)
        standardized = (latents - entry_mean) / entry_std
        latents = standardized * global_std + global_mean
    return latents * contrast_enhance_scale
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def __call__(
    self,
    prompt,
    negative_prompt="",
    cfg_scale=7.5,
    clip_skip=1,
    num_frames=None,
    input_frames=None,
    controlnet_frames=None,
    denoising_strength=1.0,
    height=512,
    width=512,
    num_inference_steps=20,
    animatediff_batch_size = 16,
    animatediff_stride = 8,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
    cross_frame_attention = False,
    smoother=None,
    smoother_progress_ids=[],
    vram_limit_level=0,
    post_normalize=False,
    contrast_enhance_scale=1.0,
    progress_bar_cmd=tqdm,
    progress_bar_st=None,
):
    """Generate video frames from a prompt, optionally starting from ``input_frames``.

    Steps through the scheduler's timesteps with classifier-free guidance
    (skipped when ``cfg_scale == 1.0``). On the steps listed in
    ``smoother_progress_ids`` the intermediate result is decoded, passed
    through ``smoother`` and re-encoded. The final latents are
    post-processed and decoded with ``decode_images``.

    Returns:
        The frames from ``decode_images``, passed through ``smoother`` once
        more when ``num_inference_steps`` or ``-1`` is in
        ``smoother_progress_ids``.
    """
    # Prepare scheduler
    self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

    # Prepare latent tensors (noise is created on the CPU)
    latent_height, latent_width = height//8, width//8
    if self.motion_modules is None:
        # Without motion modules a single noise sample is repeated for every frame.
        noise = torch.randn((1, 4, latent_height, latent_width), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
    else:
        noise = torch.randn((num_frames, 4, latent_height, latent_width), device="cpu", dtype=self.torch_dtype)
    if input_frames is not None and denoising_strength != 1.0:
        latents = self.encode_images(input_frames)
        latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
    else:
        latents = noise

    # Encode prompts, one embedding copy per frame
    prompt_emb_posi = self.prompter.encode_prompt(
        self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True
    ).cpu()
    prompt_emb_nega = self.prompter.encode_prompt(
        self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False
    ).cpu()
    prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1)
    prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1)

    # Prepare ControlNets
    if controlnet_frames is not None:
        if isinstance(controlnet_frames[0], list):
            # One frame list per processor: stack each along dim 1, then join along dim 0.
            stacked_per_processor = [
                torch.stack([
                    self.controlnet.process_image(frame, processor_id=processor_id).to(self.torch_dtype)
                    for frame in progress_bar_cmd(processor_frames)
                ], dim=1)
                for processor_id, processor_frames in enumerate(controlnet_frames)
            ]
            controlnet_frames = torch.concat(stacked_per_processor, dim=0)
        else:
            controlnet_frames = torch.stack([
                self.controlnet.process_image(frame).to(self.torch_dtype)
                for frame in progress_bar_cmd(controlnet_frames)
            ], dim=1)

    # Arguments shared by the positive and negative denoising passes
    shared_kwargs = dict(
        motion_modules=self.motion_modules, controlnet=self.controlnet,
        controlnet_frames=controlnet_frames,
        animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
        unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
        cross_frame_attention=cross_frame_attention,
        device=self.device, vram_limit_level=vram_limit_level,
    )

    # Denoise
    for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
        timestep = torch.IntTensor((timestep,))[0].to(self.device)

        # Classifier-free guidance
        noise_pred_posi = lets_dance_with_long_video(
            self.unet, sample=latents, timestep=timestep,
            encoder_hidden_states=prompt_emb_posi, **shared_kwargs
        )
        if cfg_scale != 1.0:
            noise_pred_nega = lets_dance_with_long_video(
                self.unet, sample=latents, timestep=timestep,
                encoder_hidden_states=prompt_emb_nega, **shared_kwargs
            )
            noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
        else:
            noise_pred = noise_pred_posi

        # DDIM and smoother
        if smoother is not None and progress_id in smoother_progress_ids:
            # Decode the scheduler's ``to_final=True`` result, smooth it, and
            # re-encode it as the target for ``return_to_timestep``.
            final_latents = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
            smoothed_frames = smoother(self.decode_images(final_latents), original_frames=input_frames)
            target_latents = self.encode_images(smoothed_frames)
            noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
        latents = self.scheduler.step(noise_pred, timestep, latents)

        # UI
        if progress_bar_st is not None:
            progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

    # Decode image
    latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
    output_frames = self.decode_images(latents)

    # Post-process
    smooth_final_output = num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids
    if smoother is not None and smooth_final_output:
        output_frames = smoother(output_frames, original_frames=input_frames)

    return output_frames
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SDVideoPipelineRunner:
    """Config-driven runner around ``SDVideoPipeline``.

    ``run`` loads input videos, loads the models, optionally builds a
    smoother, synthesizes the video and saves the results. With
    ``in_streamlit=True`` progress messages and the final video are also
    rendered through Streamlit.
    """

    def __init__(self, in_streamlit=False):
        self.in_streamlit = in_streamlit

    def _notify(self, message):
        """Show ``message`` in the Streamlit page; do nothing outside Streamlit."""
        if self.in_streamlit:
            # Imported lazily so Streamlit is only needed in Streamlit mode.
            import streamlit as st
            st.markdown(message)

    def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
        """Load the models and return ``(model_manager, pipe)``.

        Each entry of ``controlnet_units`` must provide the keys
        ``"processor_id"``, ``"model_path"`` and ``"scale"``.
        """
        # Load models
        model_manager = ModelManager(torch_dtype=torch.float16, device=device)
        model_manager.load_textual_inversions(textual_inversion_folder)
        model_manager.load_models(model_list, lora_alphas=lora_alphas)
        pipe = SDVideoPipeline.from_model_manager(
            model_manager,
            [
                ControlNetConfigUnit(
                    processor_id=unit["processor_id"],
                    model_path=unit["model_path"],
                    scale=unit["scale"]
                ) for unit in controlnet_units
            ]
        )
        return model_manager, pipe

    def load_smoother(self, model_manager, smoother_configs):
        """Build the smoother from ``smoother_configs``."""
        smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
        return smoother

    def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
        """Seed torch, run ``pipe`` and move the models back to the CPU.

        In Streamlit mode a progress bar is created and handed to the
        pipeline, then set to 100% when the pipeline returns.
        """
        torch.manual_seed(seed)
        if self.in_streamlit:
            import streamlit as st
            progress_bar_st = st.progress(0.0)
            output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
            progress_bar_st.progress(1.0)
        else:
            output_video = pipe(**pipeline_inputs, smoother=smoother)
        model_manager.to("cpu")
        return output_video

    def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
        """Return the frames ``[start_frame_id, end_frame_id)`` of a video.

        ``None`` bounds default to the first frame and to ``len(video)``.
        """
        video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
        if start_frame_id is None:
            start_frame_id = 0
        if end_frame_id is None:
            end_frame_id = len(video)
        frames = [video[i] for i in range(start_frame_id, end_frame_id)]
        return frames

    def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
        """Fill ``pipeline_inputs`` with frames and sizes loaded from ``data``.

        ``num_frames``, ``width`` and ``height`` are derived from the loaded
        input frames. ControlNet frames are only added when
        ``data["controlnet_frames"]`` is non-empty.
        """
        pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
        pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
        pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
        if len(data["controlnet_frames"]) > 0:
            pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
        return pipeline_inputs

    def save_output(self, video, output_folder, fps, config):
        """Save the frames, the encoded video and the config to ``output_folder``."""
        os.makedirs(output_folder, exist_ok=True)
        save_frames(video, os.path.join(output_folder, "frames"))
        save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
        # The loaded frame lists are emptied in place before the config is dumped as JSON.
        config["pipeline"]["pipeline_inputs"]["input_frames"] = []
        config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
        with open(os.path.join(output_folder, "config.json"), 'w') as file:
            json.dump(config, file, indent=4)

    def run(self, config):
        """Execute the whole workflow described by ``config``.

        ``config`` is modified in place: its pipeline inputs are filled
        with loaded frames and later emptied again by ``save_output``.
        """
        self._notify("Loading videos ...")
        config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
        self._notify("Loading videos ... done!")
        self._notify("Loading models ...")
        model_manager, pipe = self.load_pipeline(**config["models"])
        self._notify("Loading models ... done!")
        if "smoother_configs" in config:
            self._notify("Loading smoother ...")
            smoother = self.load_smoother(model_manager, config["smoother_configs"])
            self._notify("Loading smoother ... done!")
        else:
            smoother = None
        self._notify("Synthesizing videos ...")
        output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
        self._notify("Synthesizing videos ... done!")
        self._notify("Saving videos ...")
        self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
        self._notify("Saving videos ... done!")
        self._notify("Finished!")
        if self.in_streamlit:
            import streamlit as st
            # The video is only displayed in Streamlit, so the file is opened
            # only here, and the context manager closes the handle afterwards.
            video_path = os.path.join(config["data"]["output_folder"], "video.mp4")
            with open(video_path, 'rb') as video_file:
                st.video(video_file.read())
|
||||||
175
diffsynth/pipelines/stable_diffusion_xl.py
Normal file
175
diffsynth/pipelines/stable_diffusion_xl.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||||
|
# TODO: SDXL ControlNet
|
||||||
|
from ..prompts import SDXLPrompter
|
||||||
|
from ..schedulers import EnhancedDDIMScheduler
|
||||||
|
from .dancer import lets_dance_xl
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLImagePipeline(torch.nn.Module):
    """Image generation pipeline assembled from SDXL components.

    Holds the scheduler, prompter and the models fetched from a
    ``ModelManager``. Calling the instance runs the denoising loop and
    returns one PIL image.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__()
        self.device = device
        self.torch_dtype = torch_dtype
        self.scheduler = EnhancedDDIMScheduler()
        self.prompter = SDXLPrompter()
        # models (filled in by the fetch_* methods)
        self.text_encoder: SDXLTextEncoder = None
        self.text_encoder_2: SDXLTextEncoder2 = None
        self.unet: SDXLUNet = None
        self.vae_decoder: SDXLVAEDecoder = None
        self.vae_encoder: SDXLVAEEncoder = None
        self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
        self.ipadapter: SDXLIpAdapter = None
        # TODO: SDXL ControlNet

    def fetch_main_models(self, model_manager: ModelManager):
        """Take the text encoders, UNet and VAE halves from ``model_manager``."""
        self.text_encoder = model_manager.text_encoder
        self.text_encoder_2 = model_manager.text_encoder_2
        self.unet = model_manager.unet
        self.vae_decoder = model_manager.vae_decoder
        self.vae_encoder = model_manager.vae_encoder

    def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
        """Placeholder; currently does nothing."""
        # TODO: SDXL ControlNet
        pass

    def fetch_ipadapter(self, model_manager: ModelManager):
        """Take the IP-Adapter models from ``model_manager`` when they are loaded."""
        loaded_models = model_manager.model
        if "ipadapter_xl" in loaded_models:
            self.ipadapter = model_manager.ipadapter_xl
        if "ipadapter_xl_image_encoder" in loaded_models:
            self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder

    def fetch_prompter(self, model_manager: ModelManager):
        """Let the prompter load what it needs from ``model_manager``."""
        self.prompter.load_from_model_manager(model_manager)

    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
        """Create a pipeline on the manager's device/dtype and fetch its models."""
        pipeline = SDXLImagePipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipeline.fetch_main_models(model_manager)
        pipeline.fetch_prompter(model_manager)
        pipeline.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
        pipeline.fetch_ipadapter(model_manager)
        return pipeline

    def preprocess_image(self, image):
        """Turn an image into a float tensor scaled to [-1, 1] with a batch axis."""
        pixels = np.array(image, dtype=np.float32)
        tensor = torch.Tensor(pixels * (2 / 255) - 1)
        # Last axis moves to the front (presumably HWC -> CHW), then a batch axis is added.
        return tensor.permute(2, 0, 1).unsqueeze(0)

    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        """Decode one latent with the VAE decoder and return it as a PIL image."""
        decoded = self.vae_decoder(
            latent.to(self.device),
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
        )
        pixels = decoded[0].cpu().permute(1, 2, 0).numpy()
        # Map [-1, 1] back to [0, 255], clamping anything outside that range.
        pixels = (pixels / 2 + 0.5).clip(0, 1) * 255
        return Image.fromarray(pixels.astype("uint8"))

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        clip_skip_2=2,
        input_image=None,
        ipadapter_images=None,
        ipadapter_scale=1.0,
        controlnet_image=None,
        denoising_strength=1.0,
        height=1024,
        width=1024,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate an image from ``prompt``, optionally starting from ``input_image``.

        The negative prompt and negative UNet pass are skipped when
        ``cfg_scale == 1.0``. ``controlnet_image`` is accepted but not read
        by this method. Returns the decoded PIL image.
        """
        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        latent_shape = (1, 4, height//8, width//8)
        if input_image is None:
            latents = torch.randn(latent_shape, device=self.device, dtype=self.torch_dtype)
        else:
            image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
            # The VAE encoder is fed float32 input; its output is cast back to the pipeline dtype.
            latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
            noise = torch.randn(latent_shape, device=self.device, dtype=self.torch_dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])

        # Encode prompts
        add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
            self.text_encoder,
            self.text_encoder_2,
            prompt,
            clip_skip=clip_skip, clip_skip_2=clip_skip_2,
            device=self.device,
            positive=True,
        )
        use_guidance = cfg_scale != 1.0
        if use_guidance:
            add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
                self.text_encoder,
                self.text_encoder_2,
                negative_prompt,
                clip_skip=clip_skip, clip_skip_2=clip_skip_2,
                device=self.device,
                positive=False,
            )

        # Prepare positional id
        add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)

        # IP-Adapter
        if ipadapter_images is None:
            ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
        else:
            ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
            ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
            # The negative branch uses an all-zero image encoding.
            ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))

        # Arguments shared by the positive and negative UNet passes
        shared_kwargs = dict(
            add_time_id=add_time_id,
            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
        )

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = torch.IntTensor((timestep,))[0].to(self.device)

            # Classifier-free guidance
            noise_pred_posi = lets_dance_xl(
                self.unet,
                sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi,
                add_text_embeds=add_prompt_emb_posi,
                ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
                **shared_kwargs,
            )
            if use_guidance:
                noise_pred_nega = lets_dance_xl(
                    self.unet,
                    sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega,
                    add_text_embeds=add_prompt_emb_nega,
                    ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
                    **shared_kwargs,
                )
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi

            latents = self.scheduler.step(noise_pred, timestep, latents)

            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        return self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
190
diffsynth/pipelines/stable_diffusion_xl_video.py
Normal file
190
diffsynth/pipelines/stable_diffusion_xl_video.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLMotionModel
|
||||||
|
from .dancer import lets_dance_xl
|
||||||
|
# TODO: SDXL ControlNet
|
||||||
|
from ..prompts import SDXLPrompter
|
||||||
|
from ..schedulers import EnhancedDDIMScheduler
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLVideoPipeline(torch.nn.Module):
    """Video generation pipeline assembled from SDXL components.

    Holds the scheduler, prompter and the models fetched from a
    ``ModelManager``. Calling the instance runs the denoising loop and
    returns a list of decoded PIL images, one per frame.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
        super().__init__()
        # The scheduler's beta schedule depends on whether AnimateDiff is used.
        self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
        self.prompter = SDXLPrompter()
        self.device = device
        self.torch_dtype = torch_dtype
        # models (filled in by the fetch_* methods)
        self.text_encoder: SDXLTextEncoder = None
        self.text_encoder_2: SDXLTextEncoder2 = None
        self.unet: SDXLUNet = None
        self.vae_decoder: SDXLVAEDecoder = None
        self.vae_encoder: SDXLVAEEncoder = None
        # TODO: SDXL ControlNet
        self.motion_modules: SDXLMotionModel = None

    def fetch_main_models(self, model_manager: ModelManager):
        """Take the text encoders, UNet and VAE halves from ``model_manager``."""
        self.text_encoder = model_manager.text_encoder
        self.text_encoder_2 = model_manager.text_encoder_2
        self.unet = model_manager.unet
        self.vae_decoder = model_manager.vae_decoder
        self.vae_encoder = model_manager.vae_encoder

    def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
        """Placeholder; currently does nothing."""
        # TODO: SDXL ControlNet
        pass

    def fetch_motion_modules(self, model_manager: ModelManager):
        """Take the XL motion modules from ``model_manager`` when they are loaded."""
        if "motion_modules_xl" in model_manager.model:
            self.motion_modules = model_manager.motion_modules_xl

    def fetch_prompter(self, model_manager: ModelManager):
        """Let the prompter load what it needs from ``model_manager``."""
        self.prompter.load_from_model_manager(model_manager)

    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
        """Create a pipeline on the manager's device/dtype and fetch its models.

        AnimateDiff mode is enabled only when ``"motion_modules_xl"`` is
        among the manager's loaded models.
        """
        pipe = SDXLVideoPipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
            use_animatediff="motion_modules_xl" in model_manager.model
        )
        pipe.fetch_main_models(model_manager)
        pipe.fetch_motion_modules(model_manager)
        pipe.fetch_prompter(model_manager)
        pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
        return pipe

    def preprocess_image(self, image):
        """Turn an image into a float tensor scaled to [-1, 1] with a batch axis."""
        # Last axis moves to the front (presumably HWC -> CHW), then a batch axis is added.
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image

    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        """Decode one latent with the VAE decoder and return it as a PIL image."""
        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        image = image.cpu().permute(1, 2, 0).numpy()
        # Map [-1, 1] back to [0, 255], clamping anything outside that range.
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
        return image

    def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
        """Decode every entry along the first axis of ``latents`` into a list of images."""
        images = [
            self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            for frame_id in range(latents.shape[0])
        ]
        return images

    def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
        """Encode each image with the VAE encoder; returns CPU latents joined on axis 0."""
        latents = []
        for image in processed_images:
            image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
            latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
            latents.append(latent)
        latents = torch.concat(latents, dim=0)
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        clip_skip_2=2,
        num_frames=None,
        input_frames=None,
        controlnet_frames=None,
        denoising_strength=1.0,
        height=512,
        width=512,
        num_inference_steps=20,
        animatediff_batch_size = 16,
        animatediff_stride = 8,
        unet_batch_size = 1,
        controlnet_batch_size = 1,
        cross_frame_attention = False,
        smoother=None,
        smoother_progress_ids=[],
        vram_limit_level=0,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate video frames from ``prompt``, optionally starting from ``input_frames``.

        The negative prompt and negative UNet pass are skipped when
        ``cfg_scale == 1.0``. The parameters ``animatediff_batch_size``,
        ``animatediff_stride``, ``unet_batch_size``, ``controlnet_batch_size``,
        ``smoother`` and ``smoother_progress_ids`` are accepted but not read
        by this method.

        Returns:
            The list of images produced by ``decode_images``.
        """
        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        if self.motion_modules is None:
            # Without motion modules a single noise sample is repeated for every frame.
            noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
        else:
            # Use the configured device rather than a hard-coded "cuda", so
            # the pipeline honours the device it was constructed with.
            noise = torch.randn((num_frames, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
        # Keep noise and latents on one device: ``encode_images`` returns CPU
        # tensors, and mixing them with device noise in ``add_noise`` fails.
        # NOTE(review): assumes ``lets_dance_xl`` expects ``sample`` on the
        # pipeline device, as SDXLImagePipeline passes it; confirm.
        noise = noise.to(self.device)
        if input_frames is None or denoising_strength == 1.0:
            latents = noise
        else:
            latents = self.encode_images(input_frames).to(self.device)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])

        # Encode prompts
        add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
            self.text_encoder,
            self.text_encoder_2,
            prompt,
            clip_skip=clip_skip, clip_skip_2=clip_skip_2,
            device=self.device,
            positive=True,
        )
        if cfg_scale != 1.0:
            add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
                self.text_encoder,
                self.text_encoder_2,
                negative_prompt,
                clip_skip=clip_skip, clip_skip_2=clip_skip_2,
                device=self.device,
                positive=False,
            )

        # Prepare positional id
        add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = torch.IntTensor((timestep,))[0].to(self.device)

            # Classifier-free guidance
            noise_pred_posi = lets_dance_xl(
                self.unet, motion_modules=self.motion_modules, controlnet=None,
                sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
                timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
                cross_frame_attention=cross_frame_attention,
                device=self.device, vram_limit_level=vram_limit_level
            )
            if cfg_scale != 1.0:
                noise_pred_nega = lets_dance_xl(
                    self.unet, motion_modules=self.motion_modules, controlnet=None,
                    sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
                    timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
                    cross_frame_attention=cross_frame_attention,
                    device=self.device, vram_limit_level=vram_limit_level
                )
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi

            latents = self.scheduler.step(noise_pred, timestep, latents)

            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        image = self.decode_images(latents.to(torch.float32))

        return image
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder
|
from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder
|
||||||
from ..schedulers import ContinuousODEScheduler
|
from ..schedulers import ContinuousODEScheduler
|
||||||
from .base import BasePipeline
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -9,11 +8,13 @@ from einops import rearrange, repeat
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SVDVideoPipeline(BasePipeline):
|
class SVDVideoPipeline(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
super().__init__()
|
||||||
self.scheduler = ContinuousODEScheduler()
|
self.scheduler = ContinuousODEScheduler()
|
||||||
|
self.device = device
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
# models
|
# models
|
||||||
self.image_encoder: SVDImageEncoder = None
|
self.image_encoder: SVDImageEncoder = None
|
||||||
self.unet: SVDUNet = None
|
self.unet: SVDUNet = None
|
||||||
@@ -21,23 +22,32 @@ class SVDVideoPipeline(BasePipeline):
|
|||||||
self.vae_decoder: SVDVAEDecoder = None
|
self.vae_decoder: SVDVAEDecoder = None
|
||||||
|
|
||||||
|
|
||||||
def fetch_models(self, model_manager: ModelManager):
|
def fetch_main_models(self, model_manager: ModelManager):
|
||||||
self.image_encoder = model_manager.fetch_model("svd_image_encoder")
|
self.image_encoder = model_manager.image_encoder
|
||||||
self.unet = model_manager.fetch_model("svd_unet")
|
self.unet = model_manager.unet
|
||||||
self.vae_encoder = model_manager.fetch_model("svd_vae_encoder")
|
self.vae_encoder = model_manager.vae_encoder
|
||||||
self.vae_decoder = model_manager.fetch_model("svd_vae_decoder")
|
self.vae_decoder = model_manager.vae_decoder
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_manager(model_manager: ModelManager, **kwargs):
|
def from_model_manager(model_manager: ModelManager, **kwargs):
|
||||||
pipe = SVDVideoPipeline(
|
pipe = SVDVideoPipeline(device=model_manager.device, torch_dtype=model_manager.torch_dtype)
|
||||||
device=model_manager.device,
|
pipe.fetch_main_models(model_manager)
|
||||||
torch_dtype=model_manager.torch_dtype
|
|
||||||
)
|
|
||||||
pipe.fetch_models(model_manager)
|
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_image(self, image):
|
||||||
|
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||||
|
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||||
|
image = image.cpu().permute(1, 2, 0).numpy()
|
||||||
|
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
def encode_image_with_clip(self, image):
|
def encode_image_with_clip(self, image):
|
||||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||||
image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224))
|
image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224))
|
||||||
@@ -49,9 +59,9 @@ class SVDVideoPipeline(BasePipeline):
|
|||||||
return image_emb
|
return image_emb
|
||||||
|
|
||||||
|
|
||||||
def encode_image_with_vae(self, image, noise_aug_strength, seed=None):
|
def encode_image_with_vae(self, image, noise_aug_strength):
|
||||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||||
noise = self.generate_noise(image.shape, seed=seed, device=self.device, dtype=self.torch_dtype)
|
noise = torch.randn(image.shape, device="cpu", dtype=self.torch_dtype).to(self.device)
|
||||||
image = image + noise_aug_strength * noise
|
image = image + noise_aug_strength * noise
|
||||||
image_emb = self.vae_encoder(image) / self.vae_encoder.scaling_factor
|
image_emb = self.vae_encoder(image) / self.vae_encoder.scaling_factor
|
||||||
return image_emb
|
return image_emb
|
||||||
@@ -126,17 +136,14 @@ class SVDVideoPipeline(BasePipeline):
|
|||||||
num_inference_steps=20,
|
num_inference_steps=20,
|
||||||
post_normalize=True,
|
post_normalize=True,
|
||||||
contrast_enhance_scale=1.2,
|
contrast_enhance_scale=1.2,
|
||||||
seed=None,
|
|
||||||
progress_bar_cmd=tqdm,
|
progress_bar_cmd=tqdm,
|
||||||
progress_bar_st=None,
|
progress_bar_st=None,
|
||||||
):
|
):
|
||||||
height, width = self.check_resize_height_width(height, width)
|
|
||||||
|
|
||||||
# Prepare scheduler
|
# Prepare scheduler
|
||||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
|
||||||
|
|
||||||
# Prepare latent tensors
|
# Prepare latent tensors
|
||||||
noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).to(self.device)
|
||||||
if denoising_strength == 1.0:
|
if denoising_strength == 1.0:
|
||||||
latents = noise.clone()
|
latents = noise.clone()
|
||||||
else:
|
else:
|
||||||
@@ -150,7 +157,7 @@ class SVDVideoPipeline(BasePipeline):
|
|||||||
# Encode image
|
# Encode image
|
||||||
image_emb_clip_posi = self.encode_image_with_clip(input_image)
|
image_emb_clip_posi = self.encode_image_with_clip(input_image)
|
||||||
image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi)
|
image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi)
|
||||||
image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength, seed=seed), "B C H W -> (B T) C H W", T=num_frames)
|
image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength), "B C H W -> (B T) C H W", T=num_frames)
|
||||||
image_emb_vae_nega = torch.zeros_like(image_emb_vae_posi)
|
image_emb_vae_nega = torch.zeros_like(image_emb_vae_posi)
|
||||||
|
|
||||||
# Prepare classifier-free guidance
|
# Prepare classifier-free guidance
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
from .prompt_refiners import Translator, BeautifulPrompt, QwenPrompt
|
|
||||||
from .sd_prompter import SDPrompter
|
|
||||||
from .sdxl_prompter import SDXLPrompter
|
|
||||||
from .sd3_prompter import SD3Prompter
|
|
||||||
from .hunyuan_dit_prompter import HunyuanDiTPrompter
|
|
||||||
from .kolors_prompter import KolorsPrompter
|
|
||||||
from .flux_prompter import FluxPrompter
|
|
||||||
from .omost import OmostPromter
|
|
||||||
from .cog_prompter import CogPrompter
|
|
||||||
from .hunyuan_video_prompter import HunyuanVideoPrompter
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
from ..models.model_manager import ModelManager
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def tokenize_long_prompt(tokenizer, prompt, max_length=None):
|
|
||||||
# Get model_max_length from self.tokenizer
|
|
||||||
length = tokenizer.model_max_length if max_length is None else max_length
|
|
||||||
|
|
||||||
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
|
|
||||||
tokenizer.model_max_length = 99999999
|
|
||||||
|
|
||||||
# Tokenize it!
|
|
||||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
|
||||||
|
|
||||||
# Determine the real length.
|
|
||||||
max_length = (input_ids.shape[1] + length - 1) // length * length
|
|
||||||
|
|
||||||
# Restore tokenizer.model_max_length
|
|
||||||
tokenizer.model_max_length = length
|
|
||||||
|
|
||||||
# Tokenize it again with fixed length.
|
|
||||||
input_ids = tokenizer(
|
|
||||||
prompt,
|
|
||||||
return_tensors="pt",
|
|
||||||
padding="max_length",
|
|
||||||
max_length=max_length,
|
|
||||||
truncation=True
|
|
||||||
).input_ids
|
|
||||||
|
|
||||||
# Reshape input_ids to fit the text encoder.
|
|
||||||
num_sentence = input_ids.shape[1] // length
|
|
||||||
input_ids = input_ids.reshape((num_sentence, length))
|
|
||||||
|
|
||||||
return input_ids
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BasePrompter:
|
|
||||||
def __init__(self):
|
|
||||||
self.refiners = []
|
|
||||||
self.extenders = []
|
|
||||||
|
|
||||||
|
|
||||||
def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
|
|
||||||
for refiner_class in refiner_classes:
|
|
||||||
refiner = refiner_class.from_model_manager(model_manager)
|
|
||||||
self.refiners.append(refiner)
|
|
||||||
|
|
||||||
def load_prompt_extenders(self,model_manager:ModelManager,extender_classes=[]):
|
|
||||||
for extender_class in extender_classes:
|
|
||||||
extender = extender_class.from_model_manager(model_manager)
|
|
||||||
self.extenders.append(extender)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def process_prompt(self, prompt, positive=True):
|
|
||||||
if isinstance(prompt, list):
|
|
||||||
prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
|
|
||||||
else:
|
|
||||||
for refiner in self.refiners:
|
|
||||||
prompt = refiner(prompt, positive=positive)
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def extend_prompt(self, prompt:str, positive=True):
|
|
||||||
extended_prompt = dict(prompt=prompt)
|
|
||||||
for extender in self.extenders:
|
|
||||||
extended_prompt = extender(extended_prompt)
|
|
||||||
return extended_prompt
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
from .base_prompter import BasePrompter
|
|
||||||
from ..models.flux_text_encoder import FluxTextEncoder2
|
|
||||||
from transformers import T5TokenizerFast
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
class CogPrompter(BasePrompter):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
tokenizer_path=None
|
|
||||||
):
|
|
||||||
if tokenizer_path is None:
|
|
||||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
|
||||||
tokenizer_path = os.path.join(base_path, "tokenizer_configs/cog/tokenizer")
|
|
||||||
super().__init__()
|
|
||||||
self.tokenizer = T5TokenizerFast.from_pretrained(tokenizer_path)
|
|
||||||
self.text_encoder: FluxTextEncoder2 = None
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_models(self, text_encoder: FluxTextEncoder2 = None):
|
|
||||||
self.text_encoder = text_encoder
|
|
||||||
|
|
||||||
|
|
||||||
def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
|
|
||||||
input_ids = tokenizer(
|
|
||||||
prompt,
|
|
||||||
return_tensors="pt",
|
|
||||||
padding="max_length",
|
|
||||||
max_length=max_length,
|
|
||||||
truncation=True,
|
|
||||||
).input_ids.to(device)
|
|
||||||
prompt_emb = text_encoder(input_ids)
|
|
||||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
|
||||||
|
|
||||||
return prompt_emb
|
|
||||||
|
|
||||||
|
|
||||||
def encode_prompt(
|
|
||||||
self,
|
|
||||||
prompt,
|
|
||||||
positive=True,
|
|
||||||
device="cuda"
|
|
||||||
):
|
|
||||||
prompt = self.process_prompt(prompt, positive=positive)
|
|
||||||
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder, self.tokenizer, 226, device)
|
|
||||||
return prompt_emb
|
|
||||||
@@ -1,74 +0,0 @@
|
|||||||
from .base_prompter import BasePrompter
|
|
||||||
from ..models.flux_text_encoder import FluxTextEncoder2
|
|
||||||
from ..models.sd3_text_encoder import SD3TextEncoder1
|
|
||||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
|
||||||
import os, torch
|
|
||||||
|
|
||||||
|
|
||||||
class FluxPrompter(BasePrompter):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
tokenizer_1_path=None,
|
|
||||||
tokenizer_2_path=None
|
|
||||||
):
|
|
||||||
if tokenizer_1_path is None:
|
|
||||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
|
||||||
tokenizer_1_path = os.path.join(base_path, "tokenizer_configs/flux/tokenizer_1")
|
|
||||||
if tokenizer_2_path is None:
|
|
||||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
|
||||||
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/flux/tokenizer_2")
|
|
||||||
super().__init__()
|
|
||||||
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
|
|
||||||
self.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_path)
|
|
||||||
self.text_encoder_1: SD3TextEncoder1 = None
|
|
||||||
self.text_encoder_2: FluxTextEncoder2 = None
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: FluxTextEncoder2 = None):
|
|
||||||
self.text_encoder_1 = text_encoder_1
|
|
||||||
self.text_encoder_2 = text_encoder_2
|
|
||||||
|
|
||||||
|
|
||||||
def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device):
|
|
||||||
input_ids = tokenizer(
|
|
||||||
prompt,
|
|
||||||
return_tensors="pt",
|
|
||||||
padding="max_length",
|
|
||||||
max_length=max_length,
|
|
||||||
truncation=True
|
|
||||||
).input_ids.to(device)
|
|
||||||
pooled_prompt_emb, _ = text_encoder(input_ids)
|
|
||||||
return pooled_prompt_emb
|
|
||||||
|
|
||||||
|
|
||||||
def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
|
|
||||||
input_ids = tokenizer(
|
|
||||||
prompt,
|
|
||||||
return_tensors="pt",
|
|
||||||
padding="max_length",
|
|
||||||
max_length=max_length,
|
|
||||||
truncation=True,
|
|
||||||
).input_ids.to(device)
|
|
||||||
prompt_emb = text_encoder(input_ids)
|
|
||||||
return prompt_emb
|
|
||||||
|
|
||||||
|
|
||||||
def encode_prompt(
|
|
||||||
self,
|
|
||||||
prompt,
|
|
||||||
positive=True,
|
|
||||||
device="cuda",
|
|
||||||
t5_sequence_length=512,
|
|
||||||
):
|
|
||||||
prompt = self.process_prompt(prompt, positive=positive)
|
|
||||||
|
|
||||||
# CLIP
|
|
||||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
|
|
||||||
|
|
||||||
# T5
|
|
||||||
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
|
|
||||||
|
|
||||||
# text_ids
|
|
||||||
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)
|
|
||||||
|
|
||||||
return prompt_emb, pooled_prompt_emb, text_ids
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user