mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
100
README.md
100
README.md
@@ -2,7 +2,43 @@
|
||||
|
||||
## Introduction
|
||||
|
||||
DiffSynth is a new Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. This version is currently in its initial stage, supporting SD and SDXL architectures. In the future, we plan to develop more interesting features based on this new codebase.
|
||||
DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
|
||||
|
||||
## Roadmap
|
||||
|
||||
* Aug 29, 2023. I propose DiffSynth, a video synthesis framework.
|
||||
* [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
|
||||
* The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
|
||||
* The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
|
||||
* Oct 1, 2023. I release an early version of this project, namely FastSDXL. A try for building a diffusion engine.
|
||||
* The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
|
||||
* FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
|
||||
* The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
|
||||
* The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
|
||||
* A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
|
||||
* Since OLSS requires additional training, we don't implement it in this project.
|
||||
* Nov 15, 2023. I propose FastBlend, a powerful video deflickering algorithm.
|
||||
* The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
|
||||
* Demo videos are shown on Bilibili, including three tasks.
|
||||
* [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
|
||||
* [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
|
||||
* [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
|
||||
* The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
|
||||
* An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
|
||||
* Dec 8, 2023. I decide to develop a new Project, aiming to release the potential of diffusion models, especially in video synthesis.
|
||||
* Jan 29, 2024. I propose Diffutoon, a fantastic solution for toon shading.
|
||||
* [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/).
|
||||
* The source codes are released in this project.
|
||||
* The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
|
||||
* Until now, DiffSynth Studio has supported the following models:
|
||||
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||
* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
* [ControlNet](https://github.com/lllyasviel/ControlNet)
|
||||
* [AnimateDiff](https://github.com/guoyww/animatediff/)
|
||||
* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
|
||||
* [ESRGAN](https://github.com/xinntao/ESRGAN)
|
||||
* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
|
||||
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -30,72 +66,46 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4e
|
||||
|
||||
## Usage (in Python code)
|
||||
|
||||
### Example 1: Stable Diffusion
|
||||
The Python examples are in [`examples`](./examples/). We provide an overview here.
|
||||
|
||||
We can generate images with very high resolution. Please see `examples/sd_text_to_image.py` for more details.
|
||||
### Image Synthesis
|
||||
|
||||
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/)
|
||||
|
||||
|512*512|1024*1024|2048*2048|4096*4096|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|
||||
### Example 2: Stable Diffusion XL
|
||||
|
||||
Generate images with Stable Diffusion XL. Please see `examples/sdxl_text_to_image.py` for more details.
|
||||
|
||||
|1024*1024|2048*2048|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
### Example 3: Stable Diffusion XL Turbo
|
||||
### Toon Shading
|
||||
|
||||
Generate images with Stable Diffusion XL Turbo. You can see `examples/sdxl_turbo.py` for more details, but we highly recommend you to use it in the WebUI.
|
||||
|
||||
|"black car"|"red car"|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
### Example 4: Toon Shading (Diffutoon)
|
||||
|
||||
This example is implemented based on [Diffutoon](https://arxiv.org/abs/2401.16224). This approach is adept for rendering high-resoluton videos with rapid motion. You can easily modify the parameters in the config dict. See `examples/diffutoon_toon_shading.py`. We also provide [an example on Colab](https://colab.research.google.com/github/Artiprocher/DiffSynth-Studio/blob/main/examples/Diffutoon.ipynb).
|
||||
Render realistic videos in a flatten style and enable video editing features. [`examples/Diffutoon`](./examples/Diffutoon/)
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
|
||||
|
||||
### Example 5: Toon Shading with Editing Signals (Diffutoon)
|
||||
|
||||
This example is implemented based on [Diffutoon](https://arxiv.org/abs/2401.16224), supporting video editing signals. See `examples\diffutoon_toon_shading_with_editing_signals.py`. The editing feature is also supported in the [Colab example](https://colab.research.google.com/github/Artiprocher/DiffSynth-Studio/blob/main/examples/Diffutoon.ipynb).
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c
|
||||
|
||||
### Example 6: Toon Shading (in native Python code)
|
||||
### Video Stylization
|
||||
|
||||
This example is provided for developers. If you don't want to use the config to manage parameters, you can see `examples/sd_toon_shading.py` to learn how to use it in native Python code.
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/607c199b-6140-410b-a111-3e4ffb01142c
|
||||
|
||||
### Example 7: Text to Video
|
||||
|
||||
Given a prompt, DiffSynth Studio can generate a video using a Stable Diffusion model and an AnimateDiff model. We can break the limitation of number of frames! See `examples/sd_text_to_video.py`.
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/8f556355-4079-4445-9b48-e9da77699437
|
||||
|
||||
### Example 8: Video Stylization
|
||||
|
||||
We provide an example for video stylization. In this pipeline, the rendered video is completely different from the original video, thus we need a powerful deflickering algorithm. We use FastBlend to implement the deflickering module. Please see `examples/sd_video_rerender.py` for more details.
|
||||
Video stylization without video models. [`examples/diffsynth`](./examples/diffsynth/)
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||
|
||||
### Example 9: Prompt Processing
|
||||
### Chinese Models
|
||||
|
||||
If you are not native English user, we provide translation service for you. Our prompter can translate other language to English and refine it using "BeautifulPrompt" models. Please see `examples/sd_prompt_refining.py` for more details.
|
||||
Use Hunyuan-DiT to generate images with Chinese prompts. We also support LoRA fine-tuning of this model. [`examples/hunyuan_dit`](./examples/hunyuan_dit/)
|
||||
|
||||
Prompt: "一个漂亮的女孩". The [translation model](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh) will translate it to English.
|
||||
Prompt: 少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感
|
||||
|
||||
|seed=0|seed=1|seed=2|seed=3|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|1024x1024|2048x2048 (highres-fix)|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
Prompt: "一个漂亮的女孩". The [translation model](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh) will translate it to English. Then the [refining model](https://huggingface.co/alibaba-pai/pai-bloom-1b1-text2prompt-sd) will refine the translated prompt for better visual quality.
|
||||
Prompt: 一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉
|
||||
|
||||
|seed=0|seed=1|seed=2|seed=3|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|Without LoRA|With LoRA|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
7
configs/hunyuan_dit/tokenizer/special_tokens_map.json
Normal file
7
configs/hunyuan_dit/tokenizer/special_tokens_map.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"cls_token": "[CLS]",
|
||||
"mask_token": "[MASK]",
|
||||
"pad_token": "[PAD]",
|
||||
"sep_token": "[SEP]",
|
||||
"unk_token": "[UNK]"
|
||||
}
|
||||
16
configs/hunyuan_dit/tokenizer/tokenizer_config.json
Normal file
16
configs/hunyuan_dit/tokenizer/tokenizer_config.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"cls_token": "[CLS]",
|
||||
"do_basic_tokenize": true,
|
||||
"do_lower_case": true,
|
||||
"mask_token": "[MASK]",
|
||||
"name_or_path": "hfl/chinese-roberta-wwm-ext",
|
||||
"never_split": null,
|
||||
"pad_token": "[PAD]",
|
||||
"sep_token": "[SEP]",
|
||||
"special_tokens_map_file": "/home/chenweifeng/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext/snapshots/5c58d0b8ec1d9014354d691c538661bf00bfdb44/special_tokens_map.json",
|
||||
"strip_accents": null,
|
||||
"tokenize_chinese_chars": true,
|
||||
"tokenizer_class": "BertTokenizer",
|
||||
"unk_token": "[UNK]",
|
||||
"model_max_length": 77
|
||||
}
|
||||
47020
configs/hunyuan_dit/tokenizer/vocab.txt
Normal file
47020
configs/hunyuan_dit/tokenizer/vocab.txt
Normal file
File diff suppressed because it is too large
Load Diff
21128
configs/hunyuan_dit/tokenizer/vocab_org.txt
Normal file
21128
configs/hunyuan_dit/tokenizer/vocab_org.txt
Normal file
File diff suppressed because it is too large
Load Diff
28
configs/hunyuan_dit/tokenizer_t5/config.json
Normal file
28
configs/hunyuan_dit/tokenizer_t5/config.json
Normal file
@@ -0,0 +1,28 @@
|
||||
{
|
||||
"_name_or_path": "/home/patrick/t5/mt5-xl",
|
||||
"architectures": [
|
||||
"MT5ForConditionalGeneration"
|
||||
],
|
||||
"d_ff": 5120,
|
||||
"d_kv": 64,
|
||||
"d_model": 2048,
|
||||
"decoder_start_token_id": 0,
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"feed_forward_proj": "gated-gelu",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "mt5",
|
||||
"num_decoder_layers": 24,
|
||||
"num_heads": 32,
|
||||
"num_layers": 24,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"tokenizer_class": "T5Tokenizer",
|
||||
"transformers_version": "4.10.0.dev0",
|
||||
"use_cache": true,
|
||||
"vocab_size": 250112
|
||||
}
|
||||
1
configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json
Normal file
1
configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json
Normal file
@@ -0,0 +1 @@
|
||||
{"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>"}
|
||||
BIN
configs/hunyuan_dit/tokenizer_t5/spiece.model
Normal file
BIN
configs/hunyuan_dit/tokenizer_t5/spiece.model
Normal file
Binary file not shown.
1
configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json
Normal file
1
configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json
Normal file
@@ -0,0 +1 @@
|
||||
{"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "extra_ids": 0, "additional_special_tokens": null, "special_tokens_map_file": "", "tokenizer_file": null, "name_or_path": "google/mt5-small", "model_max_length": 256, "legacy": true}
|
||||
@@ -24,6 +24,9 @@ from .svd_vae_encoder import SVDVAEEncoder
|
||||
|
||||
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterCLIPImageEmbedder
|
||||
|
||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self, torch_dtype=torch.float16, device="cuda"):
|
||||
@@ -83,6 +86,22 @@ class ModelManager:
|
||||
param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_hunyuan_dit_clip_text_encoder(self, state_dict):
|
||||
param_name = "bert.encoder.layer.23.attention.output.dense.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_hunyuan_dit_t5_text_encoder(self, state_dict):
|
||||
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_hunyuan_dit(self, state_dict):
|
||||
param_name = "final_layer.adaLN_modulation.1.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_diffusers_vae(self, state_dict):
|
||||
param_name = "quant_conv.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def load_stable_video_diffusion(self, state_dict, components=None, file_path=""):
|
||||
component_dict = {
|
||||
"image_encoder": SVDImageEncoder,
|
||||
@@ -223,6 +242,45 @@ class ModelManager:
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
|
||||
component = "hunyuan_dit_clip_text_encoder"
|
||||
model = HunyuanDiTCLIPTextEncoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
|
||||
component = "hunyuan_dit_t5_text_encoder"
|
||||
model = HunyuanDiTT5TextEncoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_hunyuan_dit(self, state_dict, file_path=""):
|
||||
component = "hunyuan_dit"
|
||||
model = HunyuanDiT()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_diffusers_vae(self, state_dict, file_path=""):
|
||||
# TODO: detect SD and SDXL
|
||||
component = "vae_encoder"
|
||||
model = SDXLVAEEncoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
component = "vae_decoder"
|
||||
model = SDXLVAEDecoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def search_for_embeddings(self, state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
@@ -276,6 +334,14 @@ class ModelManager:
|
||||
self.load_ipadapter_xl(state_dict, file_path=file_path)
|
||||
elif self.is_ipadapter_xl_image_encoder(state_dict):
|
||||
self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
|
||||
elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
|
||||
self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
|
||||
elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
|
||||
self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
|
||||
elif self.is_hunyuan_dit(state_dict):
|
||||
self.load_hunyuan_dit(state_dict, file_path=file_path)
|
||||
elif self.is_diffusers_vae(state_dict):
|
||||
self.load_diffusers_vae(state_dict, file_path=file_path)
|
||||
|
||||
def load_models(self, file_path_list, lora_alphas=[]):
|
||||
for file_path in file_path_list:
|
||||
|
||||
@@ -34,7 +34,7 @@ class Attention(torch.nn.Module):
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
return hidden_states
|
||||
|
||||
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None):
|
||||
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
@@ -48,6 +48,9 @@ class Attention(torch.nn.Module):
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if qkv_preprocessor is not None:
|
||||
q, k, v = qkv_preprocessor(q, k, v)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
if ipadapter_kwargs is not None:
|
||||
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
|
||||
@@ -82,5 +85,5 @@ class Attention(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None):
|
||||
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs)
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
|
||||
451
diffsynth/models/hunyuan_dit.py
Normal file
451
diffsynth/models/hunyuan_dit.py
Normal file
@@ -0,0 +1,451 @@
|
||||
from .attention import Attention
|
||||
from .tiler import TileWorker
|
||||
from einops import repeat, rearrange
|
||||
import math
|
||||
import torch
|
||||
|
||||
|
||||
class HunyuanDiTRotaryEmbedding(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
|
||||
super().__init__()
|
||||
self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
|
||||
self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
|
||||
self.rotary_emb_on_k = rotary_emb_on_k
|
||||
self.k_cache, self.v_cache = [], []
|
||||
|
||||
def reshape_for_broadcast(self, freqs_cis, x):
|
||||
ndim = x.ndim
|
||||
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
||||
|
||||
def rotate_half(self, x):
|
||||
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
def apply_rotary_emb(self, xq, xk, freqs_cis):
|
||||
xk_out = None
|
||||
cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
|
||||
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
||||
xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
|
||||
if xk is not None:
|
||||
xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
|
||||
return xq_out, xk_out
|
||||
|
||||
def forward(self, q, k, v, freqs_cis_img, to_cache=False):
|
||||
# norm
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# RoPE
|
||||
if self.rotary_emb_on_k:
|
||||
q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
|
||||
else:
|
||||
q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)
|
||||
|
||||
if to_cache:
|
||||
self.k_cache.append(k)
|
||||
self.v_cache.append(v)
|
||||
elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
|
||||
k = torch.concat([k] + self.k_cache, dim=2)
|
||||
v = torch.concat([v] + self.v_cache, dim=2)
|
||||
self.k_cache, self.v_cache = [], []
|
||||
return q, k, v
|
||||
|
||||
|
||||
class FP32_Layernorm(torch.nn.LayerNorm):
|
||||
def forward(self, inputs):
|
||||
origin_dtype = inputs.dtype
|
||||
return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)
|
||||
|
||||
|
||||
class FP32_SiLU(torch.nn.SiLU):
|
||||
def forward(self, inputs):
|
||||
origin_dtype = inputs.dtype
|
||||
return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)
|
||||
|
||||
|
||||
class HunyuanDiTFinalLayer(torch.nn.Module):
|
||||
def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
|
||||
super().__init__()
|
||||
self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = torch.nn.Sequential(
|
||||
FP32_SiLU(),
|
||||
torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def modulate(self, x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
def forward(self, hidden_states, condition_emb):
|
||||
shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
|
||||
hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
|
||||
hidden_states = self.linear(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanDiTBlock(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim=1408,
|
||||
condition_dim=1408,
|
||||
num_heads=16,
|
||||
mlp_ratio=4.3637,
|
||||
text_dim=1024,
|
||||
skip_connection=False
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
||||
self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
|
||||
self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
|
||||
self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
||||
self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
|
||||
self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
|
||||
self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
||||
self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
|
||||
self.mlp = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
|
||||
)
|
||||
if skip_connection:
|
||||
self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
|
||||
self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
|
||||
else:
|
||||
self.skip_norm, self.skip_linear = None, None
|
||||
|
||||
def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
|
||||
# Long Skip Connection
|
||||
if self.skip_norm is not None and self.skip_linear is not None:
|
||||
hidden_states = torch.cat([hidden_states, residual], dim=-1)
|
||||
hidden_states = self.skip_norm(hidden_states)
|
||||
hidden_states = self.skip_linear(hidden_states)
|
||||
|
||||
# Self-Attention
|
||||
shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
|
||||
attn_input = self.norm1(hidden_states) + shift_msa
|
||||
hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))
|
||||
|
||||
# Cross-Attention
|
||||
attn_input = self.norm3(hidden_states)
|
||||
hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))
|
||||
|
||||
# FFN Layer
|
||||
mlp_input = self.norm2(hidden_states)
|
||||
hidden_states = hidden_states + self.mlp(mlp_input)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttentionPool(torch.nn.Module):
|
||||
def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
|
||||
super().__init__()
|
||||
self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
|
||||
self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
|
||||
self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
|
||||
self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
|
||||
self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(1, 0, 2) # NLC -> LNC
|
||||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
|
||||
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
|
||||
x, _ = torch.nn.functional.multi_head_attention_forward(
|
||||
query=x[:1], key=x, value=x,
|
||||
embed_dim_to_check=x.shape[-1],
|
||||
num_heads=self.num_heads,
|
||||
q_proj_weight=self.q_proj.weight,
|
||||
k_proj_weight=self.k_proj.weight,
|
||||
v_proj_weight=self.v_proj.weight,
|
||||
in_proj_weight=None,
|
||||
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
||||
bias_k=None,
|
||||
bias_v=None,
|
||||
add_zero_attn=False,
|
||||
dropout_p=0,
|
||||
out_proj_weight=self.c_proj.weight,
|
||||
out_proj_bias=self.c_proj.bias,
|
||||
use_separate_proj_weight=True,
|
||||
training=self.training,
|
||||
need_weights=False
|
||||
)
|
||||
return x.squeeze(0)
|
||||
|
||||
|
||||
class PatchEmbed(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=(2, 2),
|
||||
in_chans=4,
|
||||
embed_dim=1408,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
return x
|
||||
|
||||
|
||||
def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
|
||||
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
||||
if not repeat_only:
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||
/ half
|
||||
).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
else:
|
||||
embedding = repeat(t, "b -> b d", d=dim)
|
||||
return embedding
|
||||
|
||||
|
||||
class TimestepEmbedder(torch.nn.Module):
|
||||
def __init__(self, hidden_size=1408, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = torch.nn.Sequential(
|
||||
torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class HunyuanDiT(torch.nn.Module):
|
||||
def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
|
||||
super().__init__()
|
||||
|
||||
# Embedders
|
||||
self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
|
||||
self.t5_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
|
||||
FP32_SiLU(),
|
||||
torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
|
||||
)
|
||||
self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
|
||||
self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
|
||||
self.patch_embedder = PatchEmbed(in_chans=in_channels)
|
||||
self.timestep_embedder = TimestepEmbedder()
|
||||
self.extra_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
|
||||
FP32_SiLU(),
|
||||
torch.nn.Linear(hidden_dim * 4, hidden_dim),
|
||||
)
|
||||
|
||||
# Transformer blocks
|
||||
self.num_layers_down = num_layers_down
|
||||
self.num_layers_up = num_layers_up
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
|
||||
[HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
|
||||
)
|
||||
|
||||
# Output layers
|
||||
self.final_layer = HunyuanDiTFinalLayer()
|
||||
self.out_channels = out_channels
|
||||
|
||||
def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
|
||||
text_emb_mask = text_emb_mask.bool()
|
||||
text_emb_mask_t5 = text_emb_mask_t5.bool()
|
||||
text_emb_t5 = self.t5_embedder(text_emb_t5)
|
||||
text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
|
||||
text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
|
||||
text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
|
||||
return text_emb
|
||||
|
||||
def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
|
||||
# Text embedding
|
||||
pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)
|
||||
|
||||
# Timestep embedding
|
||||
timestep_emb = self.timestep_embedder(timestep)
|
||||
|
||||
# Size embedding
|
||||
size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
|
||||
size_emb = size_emb.view(-1, 6 * 256)
|
||||
|
||||
# Style embedding
|
||||
style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)
|
||||
|
||||
# Concatenate all extra vectors
|
||||
extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
|
||||
condition_emb = timestep_emb + self.extra_embedder(extra_emb)
|
||||
|
||||
return condition_emb
|
||||
|
||||
def unpatchify(self, x, h, w):
|
||||
return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)
|
||||
|
||||
def build_mask(self, data, is_bound):
|
||||
_, _, H, W = data.shape
|
||||
h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
|
||||
w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
|
||||
border_width = (H + W) // 4
|
||||
pad = torch.ones_like(h) * border_width
|
||||
mask = torch.stack([
|
||||
pad if is_bound[0] else h + 1,
|
||||
pad if is_bound[1] else H - h,
|
||||
pad if is_bound[2] else w + 1,
|
||||
pad if is_bound[3] else W - w
|
||||
]).min(dim=0).values
|
||||
mask = mask.clip(1, border_width)
|
||||
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
|
||||
mask = rearrange(mask, "H W -> 1 H W")
|
||||
return mask
|
||||
|
||||
def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
|
||||
B, C, H, W = hidden_states.shape
|
||||
|
||||
weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
|
||||
values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for h in range(0, H, tile_stride):
|
||||
for w in range(0, W, tile_stride):
|
||||
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
|
||||
continue
|
||||
h_, w_ = h + tile_size, w + tile_size
|
||||
if h_ > H: h, h_ = H - tile_size, H
|
||||
if w_ > W: w, w_ = W - tile_size, W
|
||||
tasks.append((h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
for hl, hr, wl, wr in tasks:
|
||||
hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
|
||||
hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
|
||||
if residual is not None:
|
||||
residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
|
||||
residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
|
||||
else:
|
||||
residual_batch = None
|
||||
|
||||
# Forward
|
||||
hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
|
||||
hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)
|
||||
|
||||
mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
|
||||
values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
|
||||
weight[:, :, hl:hr, wl:wr] += mask
|
||||
values /= weight
|
||||
return values
|
||||
|
||||
def forward(
|
||||
self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
|
||||
tiled=False, tile_size=64, tile_stride=32,
|
||||
to_cache=False,
|
||||
use_gradient_checkpointing=False,
|
||||
):
|
||||
# Embeddings
|
||||
text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
|
||||
condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])
|
||||
|
||||
# Input
|
||||
height, width = hidden_states.shape[-2], hidden_states.shape[-1]
|
||||
hidden_states = self.patch_embedder(hidden_states)
|
||||
|
||||
# Blocks
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
if tiled:
|
||||
hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
|
||||
residuals = []
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
residual = residuals.pop() if block_id >= self.num_layers_down else None
|
||||
hidden_states = self.tiled_block_forward(
|
||||
block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
|
||||
torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
|
||||
tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
if block_id < self.num_layers_down - 2:
|
||||
residuals.append(hidden_states)
|
||||
hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
|
||||
else:
|
||||
residuals = []
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
residual = residuals.pop() if block_id >= self.num_layers_down else None
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, condition_emb, text_emb, freq_cis_img, residual,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
|
||||
if block_id < self.num_layers_down - 2:
|
||||
residuals.append(hidden_states)
|
||||
|
||||
# Output
|
||||
hidden_states = self.final_layer(hidden_states, condition_emb)
|
||||
hidden_states = self.unpatchify(hidden_states, height//2, width//2)
|
||||
hidden_states, _ = hidden_states.chunk(2, dim=1)
|
||||
return hidden_states
|
||||
|
||||
def state_dict_converter(self):
|
||||
return HunyuanDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
name_ = name
|
||||
name_ = name_.replace(".default_modulation.", ".modulation.")
|
||||
name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
|
||||
name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
|
||||
name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
|
||||
name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
|
||||
name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
|
||||
name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
|
||||
name_ = name_.replace(".q_proj.", ".to_q.")
|
||||
name_ = name_.replace(".out_proj.", ".to_out.")
|
||||
name_ = name_.replace("text_embedding_padding", "text_emb_padding")
|
||||
name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
|
||||
name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
|
||||
name_ = name_.replace("pooler.", "t5_pooler.")
|
||||
name_ = name_.replace("x_embedder.", "patch_embedder.")
|
||||
name_ = name_.replace("t_embedder.", "timestep_embedder.")
|
||||
name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
|
||||
name_ = name_.replace("style_embedder.weight", "style_embedder")
|
||||
if ".kv_proj." in name_:
|
||||
param_k = param[:param.shape[0]//2]
|
||||
param_v = param[param.shape[0]//2:]
|
||||
state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
|
||||
state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
|
||||
elif ".Wqkv." in name_:
|
||||
param_q = param[:param.shape[0]//3]
|
||||
param_k = param[param.shape[0]//3:param.shape[0]//3*2]
|
||||
param_v = param[param.shape[0]//3*2:]
|
||||
state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
|
||||
state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
|
||||
state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
|
||||
elif "style_embedder" in name_:
|
||||
state_dict_[name_] = param.squeeze()
|
||||
else:
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
161
diffsynth/models/hunyuan_dit_text_encoder.py
Normal file
161
diffsynth/models/hunyuan_dit_text_encoder.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTCLIPTextEncoder(BertModel):
|
||||
def __init__(self):
|
||||
config = BertConfig(
|
||||
_name_or_path = "",
|
||||
architectures = ["BertModel"],
|
||||
attention_probs_dropout_prob = 0.1,
|
||||
bos_token_id = 0,
|
||||
classifier_dropout = None,
|
||||
directionality = "bidi",
|
||||
eos_token_id = 2,
|
||||
hidden_act = "gelu",
|
||||
hidden_dropout_prob = 0.1,
|
||||
hidden_size = 1024,
|
||||
initializer_range = 0.02,
|
||||
intermediate_size = 4096,
|
||||
layer_norm_eps = 1e-12,
|
||||
max_position_embeddings = 512,
|
||||
model_type = "bert",
|
||||
num_attention_heads = 16,
|
||||
num_hidden_layers = 24,
|
||||
output_past = True,
|
||||
pad_token_id = 0,
|
||||
pooler_fc_size = 768,
|
||||
pooler_num_attention_heads = 12,
|
||||
pooler_num_fc_layers = 3,
|
||||
pooler_size_per_head = 128,
|
||||
pooler_type = "first_token_transform",
|
||||
position_embedding_type = "absolute",
|
||||
torch_dtype = "float32",
|
||||
transformers_version = "4.37.2",
|
||||
type_vocab_size = 2,
|
||||
use_cache = True,
|
||||
vocab_size = 47020
|
||||
)
|
||||
super().__init__(config, add_pooling_layer=False)
|
||||
self.eval()
|
||||
|
||||
def forward(self, input_ids, attention_mask, clip_skip=1):
|
||||
input_shape = input_ids.size()
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device
|
||||
|
||||
past_key_values_length = 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
past_key_values_length=0,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
)
|
||||
all_hidden_states = encoder_outputs.hidden_states
|
||||
prompt_emb = all_hidden_states[-clip_skip]
|
||||
if clip_skip > 1:
|
||||
mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
|
||||
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
||||
return prompt_emb
|
||||
|
||||
def state_dict_converter(self):
|
||||
return HunyuanDiTCLIPTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTT5TextEncoder(T5EncoderModel):
|
||||
def __init__(self):
|
||||
config = T5Config(
|
||||
_name_or_path = "../HunyuanDiT/t2i/mt5",
|
||||
architectures = ["MT5ForConditionalGeneration"],
|
||||
classifier_dropout = 0.0,
|
||||
d_ff = 5120,
|
||||
d_kv = 64,
|
||||
d_model = 2048,
|
||||
decoder_start_token_id = 0,
|
||||
dense_act_fn = "gelu_new",
|
||||
dropout_rate = 0.1,
|
||||
eos_token_id = 1,
|
||||
feed_forward_proj = "gated-gelu",
|
||||
initializer_factor = 1.0,
|
||||
is_encoder_decoder = True,
|
||||
is_gated_act = True,
|
||||
layer_norm_epsilon = 1e-06,
|
||||
model_type = "t5",
|
||||
num_decoder_layers = 24,
|
||||
num_heads = 32,
|
||||
num_layers = 24,
|
||||
output_past = True,
|
||||
pad_token_id = 0,
|
||||
relative_attention_max_distance = 128,
|
||||
relative_attention_num_buckets = 32,
|
||||
tie_word_embeddings = False,
|
||||
tokenizer_class = "T5Tokenizer",
|
||||
transformers_version = "4.37.2",
|
||||
use_cache = True,
|
||||
vocab_size = 250112
|
||||
)
|
||||
super().__init__(config)
|
||||
self.eval()
|
||||
|
||||
def forward(self, input_ids, attention_mask, clip_skip=1):
|
||||
outputs = super().forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
prompt_emb = outputs.hidden_states[-clip_skip]
|
||||
if clip_skip > 1:
|
||||
mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
|
||||
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
||||
return prompt_emb
|
||||
|
||||
def state_dict_converter(self):
|
||||
return HunyuanDiTT5TextEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTCLIPTextEncoderStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
|
||||
|
||||
class HunyuanDiTT5TextEncoderStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
|
||||
state_dict_["shared.weight"] = state_dict["shared.weight"]
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
@@ -2,4 +2,5 @@ from .stable_diffusion import SDImagePipeline
|
||||
from .stable_diffusion_xl import SDXLImagePipeline
|
||||
from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner
|
||||
from .stable_diffusion_xl_video import SDXLVideoPipeline
|
||||
from .stable_video_diffusion import SVDVideoPipeline
|
||||
from .stable_video_diffusion import SVDVideoPipeline
|
||||
from .hunyuan_dit import HunyuanDiTImagePipeline
|
||||
|
||||
298
diffsynth/pipelines/hunyuan_dit.py
Normal file
298
diffsynth/pipelines/hunyuan_dit.py
Normal file
@@ -0,0 +1,298 @@
|
||||
from ..models.hunyuan_dit import HunyuanDiT
|
||||
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
||||
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from ..models import ModelManager
|
||||
from ..prompts import HunyuanDiTPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
class ImageSizeManager:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def _to_tuple(self, x):
|
||||
if isinstance(x, int):
|
||||
return x, x
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def get_fill_resize_and_crop(self, src, tgt):
|
||||
th, tw = self._to_tuple(tgt)
|
||||
h, w = self._to_tuple(src)
|
||||
|
||||
tr = th / tw # base 分辨率
|
||||
r = h / w # 目标分辨率
|
||||
|
||||
# resize
|
||||
if r > tr:
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h)) # 根据base分辨率,将目标分辨率resize下来
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
def get_meshgrid(self, start, *args):
|
||||
if len(args) == 0:
|
||||
# start is grid_size
|
||||
num = self._to_tuple(start)
|
||||
start = (0, 0)
|
||||
stop = num
|
||||
elif len(args) == 1:
|
||||
# start is start, args[0] is stop, step is 1
|
||||
start = self._to_tuple(start)
|
||||
stop = self._to_tuple(args[0])
|
||||
num = (stop[0] - start[0], stop[1] - start[1])
|
||||
elif len(args) == 2:
|
||||
# start is start, args[0] is stop, args[1] is num
|
||||
start = self._to_tuple(start) # 左上角 eg: 12,0
|
||||
stop = self._to_tuple(args[0]) # 右下角 eg: 20,32
|
||||
num = self._to_tuple(args[1]) # 目标大小 eg: 32,124
|
||||
else:
|
||||
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
||||
|
||||
grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32) # 12-20 中间差值32份 0-32 中间差值124份
|
||||
grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0) # [2, W, H]
|
||||
return grid
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed(self, embed_dim, start, *args, use_real=True):
|
||||
grid = self.get_meshgrid(start, *args) # [2, H, w]
|
||||
grid = grid.reshape([2, 1, *grid.shape[1:]]) # 返回一个采样矩阵 分辨率与目标分辨率一致
|
||||
pos_embed = self.get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed_from_grid(self, embed_dim, grid, use_real=False):
|
||||
assert embed_dim % 4 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
|
||||
emb_w = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
|
||||
|
||||
if use_real:
|
||||
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
|
||||
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
|
||||
return cos, sin
|
||||
else:
|
||||
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(self, dim: int, pos, theta: float = 10000.0, use_real=False):
|
||||
if isinstance(pos, int):
|
||||
pos = np.arange(pos)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
|
||||
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
|
||||
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
|
||||
if use_real:
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def calc_rope(self, height, width):
|
||||
patch_size = 2
|
||||
head_size = 88
|
||||
th = height // 8 // patch_size
|
||||
tw = width // 8 // patch_size
|
||||
base_size = 512 // 8 // patch_size
|
||||
start, stop = self.get_fill_resize_and_crop((th, tw), base_size)
|
||||
sub_args = [start, stop, (th, tw)]
|
||||
rope = self.get_2d_rotary_pos_embed(head_size, *sub_args)
|
||||
return rope
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTImagePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
|
||||
self.prompter = HunyuanDiTPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
self.image_size_manager = ImageSizeManager()
|
||||
# models
|
||||
self.text_encoder: HunyuanDiTCLIPTextEncoder = None
|
||||
self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
|
||||
self.dit: HunyuanDiT = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.hunyuan_dit_clip_text_encoder
|
||||
self.text_encoder_t5 = model_manager.hunyuan_dit_t5_text_encoder
|
||||
self.dit = model_manager.hunyuan_dit
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
|
||||
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager):
|
||||
pipe = HunyuanDiTImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
def prepare_extra_input(self, height=1024, width=1024, tiled=False, tile_size=64, tile_stride=32, batch_size=1):
|
||||
if tiled:
|
||||
height, width = tile_size * 16, tile_size * 16
|
||||
image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
|
||||
freqs_cis_img = self.image_size_manager.calc_rope(height, width)
|
||||
image_meta_size = torch.stack([image_meta_size] * batch_size)
|
||||
return {
|
||||
"size_emb": image_meta_size,
|
||||
"freq_cis_img": (freqs_cis_img[0].to(dtype=self.torch_dtype, device=self.device), freqs_cis_img[1].to(dtype=self.torch_dtype, device=self.device)),
|
||||
"tiled": tiled,
|
||||
"tile_size": tile_size,
|
||||
"tile_stride": tile_stride
|
||||
}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
clip_skip_2=1,
|
||||
input_image=None,
|
||||
reference_images=[],
|
||||
reference_strengths=[0.4],
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=20,
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = noise.clone()
|
||||
|
||||
# Prepare reference latents
|
||||
reference_latents = []
|
||||
for reference_image in reference_images:
|
||||
reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
reference_latents.append(self.vae_encoder(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype))
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi, attention_mask_posi, prompt_emb_t5_posi, attention_mask_t5_posi = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_t5,
|
||||
prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_skip_2=clip_skip_2,
|
||||
positive=True,
|
||||
device=self.device
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega, attention_mask_nega, prompt_emb_t5_nega, attention_mask_t5_nega = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_t5,
|
||||
negative_prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_skip_2=clip_skip_2,
|
||||
positive=False,
|
||||
device=self.device
|
||||
)
|
||||
|
||||
# Prepare positional id
|
||||
extra_input = self.prepare_extra_input(height, width, tiled, tile_size)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# In-context reference
|
||||
for reference_latents_, reference_strength in zip(reference_latents, reference_strengths):
|
||||
if progress_id < num_inference_steps * reference_strength:
|
||||
noisy_reference_latents = self.scheduler.add_noise(reference_latents_, noise, self.scheduler.timesteps[progress_id])
|
||||
self.dit(
|
||||
noisy_reference_latents,
|
||||
prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
|
||||
timestep,
|
||||
**extra_input,
|
||||
to_cache=True
|
||||
)
|
||||
# Positive side
|
||||
noise_pred_posi = self.dit(
|
||||
latents,
|
||||
prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
|
||||
timestep,
|
||||
**extra_input,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
# Negative side
|
||||
noise_pred_nega = self.dit(
|
||||
latents,
|
||||
prompt_emb_nega, prompt_emb_t5_nega, attention_mask_nega, attention_mask_t5_nega,
|
||||
timestep,
|
||||
**extra_input
|
||||
)
|
||||
# Classifier-free guidance
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
@@ -1,176 +1,3 @@
|
||||
from transformers import CLIPTokenizer, AutoTokenizer
|
||||
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2, ModelManager
|
||||
import torch, os
|
||||
|
||||
|
||||
def tokenize_long_prompt(tokenizer, prompt):
|
||||
# Get model_max_length from self.tokenizer
|
||||
length = tokenizer.model_max_length
|
||||
|
||||
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
|
||||
tokenizer.model_max_length = 99999999
|
||||
|
||||
# Tokenize it!
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
# Determine the real length.
|
||||
max_length = (input_ids.shape[1] + length - 1) // length * length
|
||||
|
||||
# Restore tokenizer.model_max_length
|
||||
tokenizer.model_max_length = length
|
||||
|
||||
# Tokenize it again with fixed length.
|
||||
input_ids = tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True
|
||||
).input_ids
|
||||
|
||||
# Reshape input_ids to fit the text encoder.
|
||||
num_sentence = input_ids.shape[1] // length
|
||||
input_ids = input_ids.reshape((num_sentence, length))
|
||||
|
||||
return input_ids
|
||||
|
||||
|
||||
class BeautifulPrompt:
|
||||
def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
|
||||
|
||||
def __call__(self, raw_prompt):
|
||||
model_input = self.template.format(raw_prompt=raw_prompt)
|
||||
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
|
||||
outputs = self.model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=384,
|
||||
do_sample=True,
|
||||
temperature=0.9,
|
||||
top_k=50,
|
||||
top_p=0.95,
|
||||
repetition_penalty=1.1,
|
||||
num_return_sequences=1
|
||||
)
|
||||
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
|
||||
outputs[:, input_ids.size(1):],
|
||||
skip_special_tokens=True
|
||||
)[0].strip()
|
||||
return prompt
|
||||
|
||||
|
||||
class Translator:
|
||||
def __init__(self, tokenizer_path="configs/translator/tokenizer", model=None):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
|
||||
def __call__(self, prompt):
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
|
||||
output_ids = self.model.generate(input_ids)
|
||||
prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
||||
return prompt
|
||||
|
||||
|
||||
class Prompter:
|
||||
def __init__(self):
|
||||
self.tokenizer: CLIPTokenizer = None
|
||||
self.keyword_dict = {}
|
||||
self.translator: Translator = None
|
||||
self.beautiful_prompt: BeautifulPrompt = None
|
||||
|
||||
def load_textual_inversion(self, textual_inversion_dict):
|
||||
self.keyword_dict = {}
|
||||
additional_tokens = []
|
||||
for keyword in textual_inversion_dict:
|
||||
tokens, _ = textual_inversion_dict[keyword]
|
||||
additional_tokens += tokens
|
||||
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
|
||||
self.tokenizer.add_tokens(additional_tokens)
|
||||
|
||||
def load_beautiful_prompt(self, model, model_path):
|
||||
model_folder = os.path.dirname(model_path)
|
||||
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
|
||||
if model_folder.endswith("v2"):
|
||||
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
|
||||
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
|
||||
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
|
||||
but make sure there is a correlation between the input and output.\n\
|
||||
### Input: {raw_prompt}\n### Output:"""
|
||||
|
||||
def load_translator(self, model, model_path):
|
||||
model_folder = os.path.dirname(model_path)
|
||||
self.translator = Translator(tokenizer_path=model_folder, model=model)
|
||||
|
||||
def load_from_model_manager(self, model_manager: ModelManager):
|
||||
self.load_textual_inversion(model_manager.textual_inversion_dict)
|
||||
if "translator" in model_manager.model:
|
||||
self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"])
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
|
||||
def process_prompt(self, prompt, positive=True):
|
||||
for keyword in self.keyword_dict:
|
||||
if keyword in prompt:
|
||||
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
||||
if positive and self.translator is not None:
|
||||
prompt = self.translator(prompt)
|
||||
print(f"Your prompt is translated: \"{prompt}\"")
|
||||
if positive and self.beautiful_prompt is not None:
|
||||
prompt = self.beautiful_prompt(prompt)
|
||||
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
|
||||
return prompt
|
||||
|
||||
|
||||
class SDPrompter(Prompter):
|
||||
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
|
||||
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
|
||||
return prompt_emb
|
||||
|
||||
|
||||
class SDXLPrompter(Prompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_path="configs/stable_diffusion/tokenizer",
|
||||
tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2"
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
text_encoder: SDXLTextEncoder,
|
||||
text_encoder_2: SDXLTextEncoder2,
|
||||
prompt,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# 1
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
|
||||
# 2
|
||||
input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
|
||||
add_text_embeds, prompt_emb_2 = text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
|
||||
|
||||
# Merge
|
||||
prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
|
||||
|
||||
# For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
|
||||
add_text_embeds = add_text_embeds[0:1]
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
return add_text_embeds, prompt_emb
|
||||
from .sd_prompter import SDPrompter
|
||||
from .sdxl_prompter import SDXLPrompter
|
||||
from .hunyuan_dit_prompter import HunyuanDiTPrompter
|
||||
|
||||
56
diffsynth/prompts/hunyuan_dit_prompter.py
Normal file
56
diffsynth/prompts/hunyuan_dit_prompter.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from .utils import Prompter
|
||||
from transformers import BertModel, T5EncoderModel, BertTokenizer, AutoTokenizer
|
||||
import warnings
|
||||
|
||||
|
||||
class HunyuanDiTPrompter(Prompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_path="configs/hunyuan_dit/tokenizer",
|
||||
tokenizer_t5_path="configs/hunyuan_dit/tokenizer_t5"
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
self.tokenizer_t5 = AutoTokenizer.from_pretrained(tokenizer_t5_path)
|
||||
|
||||
|
||||
def encode_prompt_using_signle_model(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
clip_skip=clip_skip
|
||||
)
|
||||
return prompt_embeds, attention_mask
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
text_encoder: BertModel,
|
||||
text_encoder_t5: T5EncoderModel,
|
||||
prompt,
|
||||
clip_skip=1,
|
||||
clip_skip_2=1,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# CLIP
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_signle_model(prompt, text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device)
|
||||
|
||||
# T5
|
||||
prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_signle_model(prompt, text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device)
|
||||
|
||||
return prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5
|
||||
17
diffsynth/prompts/sd_prompter.py
Normal file
17
diffsynth/prompts/sd_prompter.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .utils import Prompter, tokenize_long_prompt
|
||||
from transformers import CLIPTokenizer
|
||||
from ..models import SDTextEncoder
|
||||
|
||||
|
||||
class SDPrompter(Prompter):
|
||||
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
|
||||
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
|
||||
return prompt_emb
|
||||
43
diffsynth/prompts/sdxl_prompter.py
Normal file
43
diffsynth/prompts/sdxl_prompter.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from .utils import Prompter, tokenize_long_prompt
|
||||
from transformers import CLIPTokenizer
|
||||
from ..models import SDXLTextEncoder, SDXLTextEncoder2
|
||||
import torch
|
||||
|
||||
|
||||
class SDXLPrompter(Prompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_path="configs/stable_diffusion/tokenizer",
|
||||
tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2"
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
text_encoder: SDXLTextEncoder,
|
||||
text_encoder_2: SDXLTextEncoder2,
|
||||
prompt,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# 1
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
|
||||
# 2
|
||||
input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
|
||||
add_text_embeds, prompt_emb_2 = text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
|
||||
|
||||
# Merge
|
||||
prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
|
||||
|
||||
# For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
|
||||
add_text_embeds = add_text_embeds[0:1]
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
return add_text_embeds, prompt_emb
|
||||
123
diffsynth/prompts/utils.py
Normal file
123
diffsynth/prompts/utils.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from transformers import CLIPTokenizer, AutoTokenizer
|
||||
from ..models import ModelManager
|
||||
import os
|
||||
|
||||
|
||||
def tokenize_long_prompt(tokenizer, prompt):
|
||||
# Get model_max_length from self.tokenizer
|
||||
length = tokenizer.model_max_length
|
||||
|
||||
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
|
||||
tokenizer.model_max_length = 99999999
|
||||
|
||||
# Tokenize it!
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
# Determine the real length.
|
||||
max_length = (input_ids.shape[1] + length - 1) // length * length
|
||||
|
||||
# Restore tokenizer.model_max_length
|
||||
tokenizer.model_max_length = length
|
||||
|
||||
# Tokenize it again with fixed length.
|
||||
input_ids = tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True
|
||||
).input_ids
|
||||
|
||||
# Reshape input_ids to fit the text encoder.
|
||||
num_sentence = input_ids.shape[1] // length
|
||||
input_ids = input_ids.reshape((num_sentence, length))
|
||||
|
||||
return input_ids
|
||||
|
||||
|
||||
class BeautifulPrompt:
|
||||
def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
|
||||
|
||||
def __call__(self, raw_prompt):
|
||||
model_input = self.template.format(raw_prompt=raw_prompt)
|
||||
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
|
||||
outputs = self.model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=384,
|
||||
do_sample=True,
|
||||
temperature=0.9,
|
||||
top_k=50,
|
||||
top_p=0.95,
|
||||
repetition_penalty=1.1,
|
||||
num_return_sequences=1
|
||||
)
|
||||
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
|
||||
outputs[:, input_ids.size(1):],
|
||||
skip_special_tokens=True
|
||||
)[0].strip()
|
||||
return prompt
|
||||
|
||||
|
||||
class Translator:
|
||||
def __init__(self, tokenizer_path="configs/translator/tokenizer", model=None):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
|
||||
def __call__(self, prompt):
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
|
||||
output_ids = self.model.generate(input_ids)
|
||||
prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
||||
return prompt
|
||||
|
||||
|
||||
class Prompter:
|
||||
def __init__(self):
|
||||
self.tokenizer: CLIPTokenizer = None
|
||||
self.keyword_dict = {}
|
||||
self.translator: Translator = None
|
||||
self.beautiful_prompt: BeautifulPrompt = None
|
||||
|
||||
def load_textual_inversion(self, textual_inversion_dict):
|
||||
self.keyword_dict = {}
|
||||
additional_tokens = []
|
||||
for keyword in textual_inversion_dict:
|
||||
tokens, _ = textual_inversion_dict[keyword]
|
||||
additional_tokens += tokens
|
||||
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
|
||||
self.tokenizer.add_tokens(additional_tokens)
|
||||
|
||||
def load_beautiful_prompt(self, model, model_path):
|
||||
model_folder = os.path.dirname(model_path)
|
||||
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
|
||||
if model_folder.endswith("v2"):
|
||||
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
|
||||
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
|
||||
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
|
||||
but make sure there is a correlation between the input and output.\n\
|
||||
### Input: {raw_prompt}\n### Output:"""
|
||||
|
||||
def load_translator(self, model, model_path):
|
||||
model_folder = os.path.dirname(model_path)
|
||||
self.translator = Translator(tokenizer_path=model_folder, model=model)
|
||||
|
||||
def load_from_model_manager(self, model_manager: ModelManager):
|
||||
self.load_textual_inversion(model_manager.textual_inversion_dict)
|
||||
if "translator" in model_manager.model:
|
||||
self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"])
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
|
||||
def process_prompt(self, prompt, positive=True):
|
||||
for keyword in self.keyword_dict:
|
||||
if keyword in prompt:
|
||||
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
||||
if positive and self.translator is not None:
|
||||
prompt = self.translator(prompt)
|
||||
print(f"Your prompt is translated: \"{prompt}\"")
|
||||
if positive and self.beautiful_prompt is not None:
|
||||
prompt = self.beautiful_prompt(prompt)
|
||||
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
|
||||
return prompt
|
||||
@@ -3,7 +3,7 @@ import torch, math
|
||||
|
||||
class EnhancedDDIMScheduler():
|
||||
|
||||
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"):
|
||||
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon"):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
if beta_schedule == "scaled_linear":
|
||||
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
|
||||
@@ -13,6 +13,7 @@ class EnhancedDDIMScheduler():
|
||||
raise NotImplementedError(f"{beta_schedule} is not implemented")
|
||||
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist()
|
||||
self.set_timesteps(10)
|
||||
self.prediction_type = prediction_type
|
||||
|
||||
|
||||
def set_timesteps(self, num_inference_steps, denoising_strength=1.0):
|
||||
@@ -28,9 +29,16 @@ class EnhancedDDIMScheduler():
|
||||
|
||||
|
||||
def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
|
||||
weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
|
||||
weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
|
||||
prev_sample = sample * weight_x + model_output * weight_e
|
||||
if self.prediction_type == "epsilon":
|
||||
weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
|
||||
weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
|
||||
prev_sample = sample * weight_x + model_output * weight_e
|
||||
elif self.prediction_type == "v_prediction":
|
||||
weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
|
||||
weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
|
||||
prev_sample = sample * weight_x + model_output * weight_e
|
||||
else:
|
||||
raise NotImplementedError(f"{self.prediction_type} is not implemented")
|
||||
return prev_sample
|
||||
|
||||
|
||||
@@ -57,4 +65,9 @@ class EnhancedDDIMScheduler():
|
||||
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
|
||||
def training_target(self, sample, noise, timestep):
|
||||
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep])
|
||||
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
|
||||
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||
return target
|
||||
|
||||
@@ -18,3 +18,4 @@ dependencies:
|
||||
- imageio[ffmpeg]
|
||||
- safetensors
|
||||
- einops
|
||||
- sentencepiece
|
||||
|
||||
21
examples/Diffutoon/README.md
Normal file
21
examples/Diffutoon/README.md
Normal file
@@ -0,0 +1,21 @@
|
||||
# Diffutoon
|
||||
|
||||
[Diffutoon](https://arxiv.org/abs/2401.16224) is a toon shading approach. This approach is adept for rendering high-resoluton videos with rapid motion.
|
||||
|
||||
## Example: Toon Shading (Diffutoon)
|
||||
|
||||
Directly render realistic videos in a flatten style. In this example, you can easily modify the parameters in the config dict. See [`diffutoon_toon_shading.py`](./diffutoon_toon_shading.py). We also provide [an example on Colab](https://colab.research.google.com/github/Artiprocher/DiffSynth-Studio/blob/main/examples/Diffutoon.ipynb).
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
|
||||
|
||||
## Example: Toon Shading with Editing Signals (Diffutoon)
|
||||
|
||||
This example supports video editing signals. See [`diffutoon_toon_shading_with_editing_signals.py`](./diffutoon_toon_shading_with_editing_signals.py). The editing feature is also supported in the [Colab example](https://colab.research.google.com/github/Artiprocher/DiffSynth-Studio/blob/main/examples/Diffutoon/Diffutoon.ipynb).
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c
|
||||
|
||||
## Example: Toon Shading (in native Python code)
|
||||
|
||||
This example is provided for developers. If you don't want to use the config to manage parameters, you can see [`sd_toon_shading.py`](./sd_toon_shading.py) to learn how to use it in native Python code.
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/607c199b-6140-410b-a111-3e4ffb01142c
|
||||
3
examples/Ip-Adapter/README.md
Normal file
3
examples/Ip-Adapter/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# IP-Adapter
|
||||
|
||||
The features of IP-Adapter in DiffSynth Studio is not completed. Please wait for us.
|
||||
7
examples/diffsynth/README.md
Normal file
7
examples/diffsynth/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# DiffSynth
|
||||
|
||||
DiffSynth is the initial version of our video synthesis framework. In this framework, you can apply video deflickering algorithms to the latent space of diffusion models. You can refer to the [original repo](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth) for more details.
|
||||
|
||||
We provide an example for video stylization. In this pipeline, the rendered video is completely different from the original video, thus we need a powerful deflickering algorithm. We use FastBlend to implement the deflickering module. Please see [`sd_video_rerender.py`](./sd_video_rerender.py).
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||
265
examples/hunyuan_dit/README.md
Normal file
265
examples/hunyuan_dit/README.md
Normal file
@@ -0,0 +1,265 @@
|
||||
# Hunyuan DiT
|
||||
|
||||
Hunyuan DiT is an image generation model based on DiT. We provide training and inference support for Hunyuan DiT.
|
||||
|
||||
## Download models
|
||||
|
||||
Four files will be used for constructing Hunyuan DiT. You can download them from [huggingface](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT) or [modelscope](https://www.modelscope.cn/models/modelscope/HunyuanDiT/summary).
|
||||
|
||||
```
|
||||
models/HunyuanDiT/
|
||||
├── Put Hunyuan DiT checkpoints here.txt
|
||||
└── t2i
|
||||
├── clip_text_encoder
|
||||
│ └── pytorch_model.bin
|
||||
├── model
|
||||
│ └── pytorch_model_ema.pt
|
||||
├── mt5
|
||||
│ └── pytorch_model.bin
|
||||
└── sdxl-vae-fp16-fix
|
||||
└── diffusion_pytorch_model.bin
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
### Text-to-image with highres-fix
|
||||
|
||||
The original resolution of Hunyuan DiT is 1024x1024. If you want to use larger resolutions, please use highres-fix.
|
||||
|
||||
Hunyuan DiT is also supported in our UI.
|
||||
|
||||
```python
|
||||
from diffsynth import ModelManager, HunyuanDiTImagePipeline
|
||||
import torch
|
||||
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||
model_manager.load_models([
|
||||
"models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin",
|
||||
"models/HunyuanDiT/t2i/mt5/pytorch_model.bin",
|
||||
"models/HunyuanDiT/t2i/model/pytorch_model_ema.pt",
|
||||
"models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"
|
||||
])
|
||||
pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
# Enjoy!
|
||||
torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt="少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感",
|
||||
negative_prompt="错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,",
|
||||
num_inference_steps=50, height=1024, width=1024,
|
||||
)
|
||||
image.save("image_1024.png")
|
||||
|
||||
# Highres fix
|
||||
image = pipe(
|
||||
prompt="少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感",
|
||||
negative_prompt="错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,",
|
||||
input_image=image.resize((2048, 2048)),
|
||||
num_inference_steps=50, height=2048, width=2048,
|
||||
cfg_scale=3.0, denoising_strength=0.5, tiled=True,
|
||||
)
|
||||
image.save("image_2048.png")
|
||||
```
|
||||
|
||||
Prompt: 少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感
|
||||
|
||||
|1024x1024|2048x2048 (highres-fix)|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
### In-context reference (experimental)
|
||||
|
||||
This feature is similar to the "reference-only" mode in ControlNets. By extending the self-attention layer, the content in the reference image can be retained in the new image. Any number of reference images are supported, and the influence from each reference image can be controled by independent `reference_strengths` parameters.
|
||||
|
||||
```python
|
||||
from diffsynth import ModelManager, HunyuanDiTImagePipeline
|
||||
import torch
|
||||
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||
model_manager.load_models([
|
||||
"models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin",
|
||||
"models/HunyuanDiT/t2i/mt5/pytorch_model.bin",
|
||||
"models/HunyuanDiT/t2i/model/pytorch_model_ema.pt",
|
||||
"models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"
|
||||
])
|
||||
pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
# Generate an image as reference
|
||||
torch.manual_seed(0)
|
||||
reference_image = pipe(
|
||||
prompt="梵高,星空,油画,明亮",
|
||||
negative_prompt="",
|
||||
num_inference_steps=50, height=1024, width=1024,
|
||||
)
|
||||
reference_image.save("image_reference.png")
|
||||
|
||||
# Generate a new image with reference
|
||||
image = pipe(
|
||||
prompt="层峦叠嶂的山脉,郁郁葱葱的森林,皎洁明亮的月光,夜色下的自然美景",
|
||||
negative_prompt="",
|
||||
reference_images=[reference_image], reference_strengths=[0.4],
|
||||
num_inference_steps=50, height=1024, width=1024,
|
||||
)
|
||||
image.save("image_with_reference.png")
|
||||
```
|
||||
|
||||
Prompt: 层峦叠嶂的山脉,郁郁葱葱的森林,皎洁明亮的月光,夜色下的自然美景
|
||||
|
||||
|Reference image|Generated new image|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
## Train
|
||||
|
||||
### Install training dependency
|
||||
|
||||
```
|
||||
pip install peft lightning pandas torchvision
|
||||
```
|
||||
|
||||
### Prepare your dataset
|
||||
|
||||
We provide an example dataset [here](https://modelscope.cn/datasets/buptwq/lora-stable-diffusion-finetune/files). You need to manage the training images as follows:
|
||||
|
||||
```
|
||||
data/dog/
|
||||
└── train
|
||||
├── 00.jpg
|
||||
├── 01.jpg
|
||||
├── 02.jpg
|
||||
├── 03.jpg
|
||||
├── 04.jpg
|
||||
└── metadata.csv
|
||||
```
|
||||
|
||||
`metadata.csv`:
|
||||
|
||||
```
|
||||
file_name,text
|
||||
00.jpg,一只小狗
|
||||
01.jpg,一只小狗
|
||||
02.jpg,一只小狗
|
||||
03.jpg,一只小狗
|
||||
04.jpg,一只小狗
|
||||
```
|
||||
|
||||
### Train a LoRA model
|
||||
|
||||
We provide a training script `train_hunyuan_dit_lora.py`. Before you run this training script, please copy it to the root directory of this project.
|
||||
|
||||
If GPU memory >= 24GB, we recommmand to use the following settings.
|
||||
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES="0" python train_hunyuan_dit_lora.py \
|
||||
--pretrained_path models/HunyuanDiT/t2i \
|
||||
--dataset_path data/dog \
|
||||
--output_path ./models \
|
||||
--max_epochs 1 \
|
||||
--center_crop
|
||||
```
|
||||
|
||||
If 12GB <= GPU memory <= 24GB, we recommand to enable gradient checkpointing.
|
||||
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES="0" python train_hunyuan_dit_lora.py \
|
||||
--pretrained_path models/HunyuanDiT/t2i \
|
||||
--dataset_path data/dog \
|
||||
--output_path ./models \
|
||||
--max_epochs 1 \
|
||||
--center_crop \
|
||||
--use_gradient_checkpointing
|
||||
```
|
||||
|
||||
Optional arguments:
|
||||
```
|
||||
-h, --help show this help message and exit
|
||||
--pretrained_path PRETRAINED_PATH
|
||||
Path to pretrained model. For example, `./HunyuanDiT/t2i`.
|
||||
--dataset_path DATASET_PATH
|
||||
The path of the Dataset.
|
||||
--output_path OUTPUT_PATH
|
||||
Path to save the model.
|
||||
--steps_per_epoch STEPS_PER_EPOCH
|
||||
Number of steps per epoch.
|
||||
--height HEIGHT Image height.
|
||||
--width WIDTH Image width.
|
||||
--center_crop Whether to center crop the input images to the resolution. If not set, the images will be randomly cropped. The images will be resized to the resolution first before cropping.
|
||||
--random_flip Whether to randomly flip images horizontally
|
||||
--batch_size BATCH_SIZE
|
||||
Batch size (per device) for the training dataloader.
|
||||
--dataloader_num_workers DATALOADER_NUM_WORKERS
|
||||
Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
|
||||
--precision {32,16,16-mixed}
|
||||
Training precision
|
||||
--learning_rate LEARNING_RATE
|
||||
Learning rate.
|
||||
--lora_rank LORA_RANK
|
||||
The dimension of the LoRA update matrices.
|
||||
--lora_alpha LORA_ALPHA
|
||||
The weight of the LoRA update matrices.
|
||||
--use_gradient_checkpointing
|
||||
Whether to use gradient checkpointing.
|
||||
--accumulate_grad_batches ACCUMULATE_GRAD_BATCHES
|
||||
The number of batches in gradient accumulation.
|
||||
--training_strategy {auto,deepspeed_stage_1,deepspeed_stage_2,deepspeed_stage_3}
|
||||
Training strategy
|
||||
--max_epochs MAX_EPOCHS
|
||||
Number of epochs.
|
||||
```
|
||||
|
||||
### Inference with your own LoRA model
|
||||
|
||||
After training, you can use your own LoRA model to generate new images. Here are some examples.
|
||||
|
||||
```python
|
||||
from diffsynth import ModelManager, HunyuanDiTImagePipeline
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
import torch
|
||||
|
||||
|
||||
def load_lora(dit, lora_rank, lora_alpha, lora_path):
|
||||
lora_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out"],
|
||||
)
|
||||
dit = inject_adapter_in_model(lora_config, dit)
|
||||
state_dict = torch.load(lora_path, map_location="cpu")
|
||||
dit.load_state_dict(state_dict, strict=False)
|
||||
return dit
|
||||
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||
model_manager.load_models([
|
||||
"models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin",
|
||||
"models/HunyuanDiT/t2i/mt5/pytorch_model.bin",
|
||||
"models/HunyuanDiT/t2i/model/pytorch_model_ema.pt",
|
||||
"models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"
|
||||
])
|
||||
pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
# Generate an image with lora
|
||||
pipe.dit = load_lora(
|
||||
pipe.dit, lora_rank=4, lora_alpha=4.0,
|
||||
lora_path="path/to/your/lora/model/lightning_logs/version_x/checkpoints/epoch=x-step=xxx.ckpt"
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt="一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉",
|
||||
negative_prompt="",
|
||||
num_inference_steps=50, height=1024, width=1024,
|
||||
)
|
||||
image.save("image_with_lora.png")
|
||||
```
|
||||
|
||||
Prompt: 一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉
|
||||
|
||||
|Without LoRA|With LoRA|
|
||||
|-|-|
|
||||
|||
|
||||
298
examples/hunyuan_dit/train_hunyuan_dit_lora.py
Normal file
298
examples/hunyuan_dit/train_hunyuan_dit_lora.py
Normal file
@@ -0,0 +1,298 @@
|
||||
from diffsynth import ModelManager, HunyuanDiTImagePipeline
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
import lightning as pl
|
||||
import pandas as pd
|
||||
import torch, os, argparse
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "True"
|
||||
|
||||
|
||||
|
||||
class TextImageDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
|
||||
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
self.text = metadata["text"].to_list()
|
||||
self.image_processor = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
|
||||
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
data_id = torch.randint(0, len(self.path), (1,))[0]
|
||||
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
||||
text = self.text[data_id]
|
||||
image = Image.open(self.path[data_id]).convert("RGB")
|
||||
image = self.image_processor(image)
|
||||
return {"text": text, "image": image}
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.steps_per_epoch
|
||||
|
||||
|
||||
|
||||
class LightningModel(pl.LightningModule):
|
||||
def __init__(self, torch_dtype=torch.float16, learning_rate=1e-4, pretrained_weights=[], lora_rank=4, lora_alpha=4, use_gradient_checkpointing=True):
|
||||
super().__init__()
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
|
||||
model_manager.load_models(pretrained_weights)
|
||||
self.pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
# Freeze parameters
|
||||
self.pipe.text_encoder.requires_grad_(False)
|
||||
self.pipe.text_encoder_t5.requires_grad_(False)
|
||||
self.pipe.dit.requires_grad_(False)
|
||||
self.pipe.vae_decoder.requires_grad_(False)
|
||||
self.pipe.vae_encoder.requires_grad_(False)
|
||||
self.pipe.text_encoder.eval()
|
||||
self.pipe.text_encoder_t5.eval()
|
||||
self.pipe.dit.train()
|
||||
self.pipe.vae_decoder.eval()
|
||||
self.pipe.vae_encoder.eval()
|
||||
|
||||
# Add LoRA to DiT
|
||||
lora_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out"],
|
||||
)
|
||||
self.pipe.dit = inject_adapter_in_model(lora_config, self.pipe.dit)
|
||||
for param in self.pipe.dit.parameters():
|
||||
# Upcast LoRA parameters into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
|
||||
# Set other parameters
|
||||
self.learning_rate = learning_rate
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
# Data
|
||||
text, image = batch["text"], batch["image"]
|
||||
|
||||
# Prepare input parameters
|
||||
self.pipe.device = self.device
|
||||
prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5 = self.pipe.prompter.encode_prompt(
|
||||
self.pipe.text_encoder, self.pipe.text_encoder_t5, text, positive=True, device=self.device
|
||||
)
|
||||
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
|
||||
noise = torch.randn_like(latents)
|
||||
timestep = torch.randint(0, 1000, (1,), device=self.device)
|
||||
extra_input = self.pipe.prepare_extra_input(image.shape[-2], image.shape[-1], batch_size=latents.shape[0])
|
||||
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
|
||||
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
|
||||
|
||||
# Compute loss
|
||||
noise_pred = self.pipe.dit(
|
||||
noisy_latents,
|
||||
prompt_emb, prompt_emb_t5, attention_mask, attention_mask_t5,
|
||||
timestep,
|
||||
**extra_input,
|
||||
use_gradient_checkpointing=self.use_gradient_checkpointing
|
||||
)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred, training_target)
|
||||
|
||||
# Record log
|
||||
self.log("train_loss", loss, prog_bar=True)
|
||||
return loss
|
||||
|
||||
|
||||
def configure_optimizers(self):
|
||||
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.dit.parameters())
|
||||
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
|
||||
return optimizer
|
||||
|
||||
|
||||
def on_save_checkpoint(self, checkpoint):
|
||||
checkpoint.clear()
|
||||
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.dit.named_parameters()))
|
||||
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
|
||||
state_dict = self.pipe.dit.state_dict()
|
||||
for name, param in state_dict.items():
|
||||
if name in trainable_param_names:
|
||||
checkpoint[name] = param
|
||||
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model. For example, `./HunyuanDiT/t2i`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="The path of the Dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
type=str,
|
||||
default="./",
|
||||
help="Path to save the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps_per_epoch",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Number of steps per epoch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Image height.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Image width.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
||||
" cropped. The images will be resized to the resolution first before cropping."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random_flip",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to randomly flip images horizontally",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Batch size (per device) for the training dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
default="16-mixed",
|
||||
choices=["32", "16", "16-mixed"],
|
||||
help="Training precision",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="Learning rate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_rank",
|
||||
type=int,
|
||||
default=4,
|
||||
help="The dimension of the LoRA update matrices.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_alpha",
|
||||
type=float,
|
||||
default=4.0,
|
||||
help="The weight of the LoRA update matrices.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_gradient_checkpointing",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use gradient checkpointing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--accumulate_grad_batches",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of batches in gradient accumulation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--training_strategy",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
|
||||
help="Training strategy",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_epochs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of epochs.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# args
|
||||
args = parse_args()
|
||||
|
||||
# dataset and data loader
|
||||
dataset = TextImageDataset(
|
||||
args.dataset_path,
|
||||
steps_per_epoch=args.steps_per_epoch * args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
center_crop=args.center_crop,
|
||||
random_flip=args.random_flip
|
||||
)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
shuffle=True,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.dataloader_num_workers
|
||||
)
|
||||
|
||||
# model
|
||||
model = LightningModel(
|
||||
pretrained_weights=[
|
||||
os.path.join(args.pretrained_path, "clip_text_encoder/pytorch_model.bin"),
|
||||
os.path.join(args.pretrained_path, "mt5/pytorch_model.bin"),
|
||||
os.path.join(args.pretrained_path, "model/pytorch_model_ema.pt"),
|
||||
os.path.join(args.pretrained_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
|
||||
],
|
||||
torch_dtype=torch.float32 if args.precision == "32" else torch.float16,
|
||||
learning_rate=args.learning_rate,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing
|
||||
)
|
||||
|
||||
# train
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=args.max_epochs,
|
||||
accelerator="gpu",
|
||||
devices="auto",
|
||||
precision=args.precision,
|
||||
strategy=args.training_strategy,
|
||||
default_root_dir=args.output_path,
|
||||
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
|
||||
)
|
||||
trainer.fit(model=model, train_dataloaders=train_loader)
|
||||
43
examples/image_synthesis/README.md
Normal file
43
examples/image_synthesis/README.md
Normal file
@@ -0,0 +1,43 @@
|
||||
# Image Synthesis
|
||||
|
||||
Image synthesis is the base feature of DiffSynth Studio.
|
||||
|
||||
### Example: Stable Diffusion
|
||||
|
||||
We can generate images with very high resolution. Please see [`sd_text_to_image.py`](./sd_text_to_image.py) for more details.
|
||||
|
||||
|512*512|1024*1024|2048*2048|4096*4096|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|
||||
### Example: Stable Diffusion XL
|
||||
|
||||
Generate images with Stable Diffusion XL. Please see [`sdxl_text_to_image.py`](./sdxl_text_to_image.py) for more details.
|
||||
|
||||
|1024*1024|2048*2048|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
### Example: Stable Diffusion XL Turbo
|
||||
|
||||
Generate images with Stable Diffusion XL Turbo. You can see [`sdxl_turbo.py`](./sdxl_turbo.py) for more details, but we highly recommend you to use it in the WebUI.
|
||||
|
||||
|"black car"|"red car"|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
### Example: Prompt Processing
|
||||
|
||||
If you are not native English user, we provide translation service for you. Our prompter can translate other language to English and refine it using "BeautifulPrompt" models. Please see [`sd_prompt_refining.py`](./sd_prompt_refining.py) for more details.
|
||||
|
||||
Prompt: "一个漂亮的女孩". The [translation model](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh) will translate it to English.
|
||||
|
||||
|seed=0|seed=1|seed=2|seed=3|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|
||||
Prompt: "一个漂亮的女孩". The [translation model](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh) will translate it to English. Then the [refining model](https://huggingface.co/alibaba-pai/pai-bloom-1b1-text2prompt-sd) will refine the translated prompt for better visual quality.
|
||||
|
||||
|seed=0|seed=1|seed=2|seed=3|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
9
examples/video_synthesis/README.md
Normal file
9
examples/video_synthesis/README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# Text to Video
|
||||
|
||||
In DiffSynth Studio, we can use AnimateDiff and SVD to generate videos. However, these models usually generate terrible contents. We do not recommend users to use these models, until a more powerful video model emerges.
|
||||
|
||||
### Example 7: Text to Video
|
||||
|
||||
Generate a video using a Stable Diffusion model and an AnimateDiff model. We can break the limitation of number of frames! See [sd_text_to_video.py](./sd_text_to_video.py).
|
||||
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/8f556355-4079-4445-9b48-e9da77699437
|
||||
@@ -5,7 +5,7 @@ import streamlit as st
|
||||
st.set_page_config(layout="wide")
|
||||
from streamlit_drawable_canvas import st_canvas
|
||||
from diffsynth.models import ModelManager
|
||||
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline
|
||||
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, HunyuanDiTImagePipeline
|
||||
from diffsynth.data.video import crop_and_resize
|
||||
|
||||
|
||||
@@ -30,14 +30,23 @@ config = {
|
||||
"height": 512,
|
||||
"width": 512,
|
||||
}
|
||||
}
|
||||
},
|
||||
"HunyuanDiT": {
|
||||
"model_folder": "models/HunyuanDiT",
|
||||
"pipeline_class": HunyuanDiTImagePipeline,
|
||||
"fixed_parameters": {
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def load_model_list(model_type):
|
||||
folder = config[model_type]["model_folder"]
|
||||
file_list = os.listdir(folder)
|
||||
file_list = [i for i in file_list if i.endswith(".safetensors")]
|
||||
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
|
||||
if model_type == "HunyuanDiT":
|
||||
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
|
||||
file_list = sorted(file_list)
|
||||
return file_list
|
||||
|
||||
@@ -53,7 +62,15 @@ def release_model():
|
||||
|
||||
def load_model(model_type, model_path):
|
||||
model_manager = ModelManager()
|
||||
model_manager.load_model(model_path)
|
||||
if model_type == "HunyuanDiT":
|
||||
model_manager.load_models([
|
||||
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
|
||||
os.path.join(model_path, "mt5/pytorch_model.bin"),
|
||||
os.path.join(model_path, "model/pytorch_model_ema.pt"),
|
||||
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
|
||||
])
|
||||
else:
|
||||
model_manager.load_model(model_path)
|
||||
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
|
||||
st.session_state.loaded_model_path = model_path
|
||||
st.session_state.model_manager = model_manager
|
||||
@@ -109,7 +126,7 @@ column_input, column_output = st.columns(2)
|
||||
with st.sidebar:
|
||||
# Select a model
|
||||
with st.expander("Model", expanded=True):
|
||||
model_type = st.selectbox("Model type", ["Stable Diffusion", "Stable Diffusion XL", "Stable Diffusion XL Turbo"])
|
||||
model_type = st.selectbox("Model type", [model_type_ for model_type_ in config])
|
||||
fixed_parameters = config[model_type]["fixed_parameters"]
|
||||
model_path_list = ["None"] + load_model_list(model_type)
|
||||
model_path = st.selectbox("Model path", model_path_list)
|
||||
@@ -124,13 +141,17 @@ with st.sidebar:
|
||||
model_path = os.path.join(config[model_type]["model_folder"], model_path)
|
||||
if st.session_state.get("loaded_model_path", "") != model_path:
|
||||
# The loaded model is not the selected model. Reload it.
|
||||
st.markdown(f"Using model at {model_path}.")
|
||||
st.markdown(f"Loading model at {model_path}.")
|
||||
st.markdown("Please wait a moment...")
|
||||
release_model()
|
||||
model_manager, pipeline = load_model(model_type, model_path)
|
||||
st.markdown("Done.")
|
||||
else:
|
||||
# The loaded model is not the selected model. Fetch it from `st.session_state`.
|
||||
st.markdown(f"Using model at {model_path}.")
|
||||
st.markdown(f"Loading model at {model_path}.")
|
||||
st.markdown("Please wait a moment...")
|
||||
model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline
|
||||
st.markdown("Done.")
|
||||
|
||||
# Show parameters
|
||||
with st.expander("Prompt", expanded=True):
|
||||
|
||||
Reference in New Issue
Block a user