diff --git a/DiffSynth_Studio.py b/DiffSynth_Studio.py new file mode 100644 index 0000000..855e5a5 --- /dev/null +++ b/DiffSynth_Studio.py @@ -0,0 +1,15 @@ +# Set web page format +import streamlit as st +st.set_page_config(layout="wide") +# Diasble virtual VRAM on windows system +import torch +torch.cuda.set_per_process_memory_fraction(0.999, 0) + + +st.markdown(""" +# DiffSynth Studio + +[Source Code](https://github.com/Artiprocher/DiffSynth-Studio) + +Welcome to DiffSynth Studio. +""") diff --git a/README.md b/README.md index 106b5b1..66c57cc 100644 --- a/README.md +++ b/README.md @@ -2,76 +2,56 @@ ## Introduction -This branch supports video-to-video translation and is still under development. +DiffSynth is a new Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. This version is currently in its initial stage, supporting SD and SDXL architectures. In the future, we plan to develop more interesting features based on this new codebase. 
## Installation +Create Python environment: + ``` conda env create -f environment.yml ``` -## Usage +Enter the Python environment: -### Example 1: Toon Shading - -https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/53532f0e-39b1-4791-b920-c975d52ec24a - -You can download the models as follows: - -* `models/stable_diffusion/flat2DAnimerge_v45Sharp.safetensors`: [link](https://civitai.com/api/download/models/266360?type=Model&format=SafeTensor&size=pruned&fp=fp16) -* `models/AnimateDiff/mm_sd_v15_v2.ckpt`: [link](https://huggingface.co/guoyww/animatediff/resolve/main/mm_sd_v15_v2.ckpt) -* `models/ControlNet/control_v11p_sd15_lineart.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_lineart.pth) -* `models/ControlNet/control_v11f1e_sd15_tile.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1e_sd15_tile.pth) -* `models/Annotators/sk_model.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model.pth) -* `models/Annotators/sk_model2.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model2.pth) - -```python -from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames -import torch - - -# Load models -model_manager = ModelManager(torch_dtype=torch.float16, device="cuda") -model_manager.load_textual_inversions("models/textual_inversion") -model_manager.load_models([ - "models/stable_diffusion/flat2DAnimerge_v45Sharp.safetensors", - "models/AnimateDiff/mm_sd_v15_v2.ckpt", - "models/ControlNet/control_v11p_sd15_lineart.pth", - "models/ControlNet/control_v11f1e_sd15_tile.pth", -]) -pipe = SDVideoPipeline.from_model_manager( - model_manager, - [ - ControlNetConfigUnit( - processor_id="lineart", - model_path="models/ControlNet/control_v11p_sd15_lineart.pth", - scale=1.0 - ), - ControlNetConfigUnit( - processor_id="tile", - model_path="models/ControlNet/control_v11f1e_sd15_tile.pth", - 
scale=0.5 - ), - ] -) - -# Load video -video = VideoData(video_file="data/66dance/raw.mp4", height=1536, width=1536) -input_video = [video[i] for i in range(40*60, 40*60+16)] - -# Toon shading -torch.manual_seed(0) -output_video = pipe( - prompt="best quality, perfect anime illustration, light, a girl is dancing, smile, solo", - negative_prompt="verybadimagenegative_v1.3", - cfg_scale=5, clip_skip=2, - controlnet_frames=input_video, num_frames=16, - num_inference_steps=10, height=1536, width=1536, - vram_limit_level=0, -) - -# Save images and video -save_frames(output_video, "data/text2video/frames") -save_video(output_video, "data/text2video/video.mp4", fps=16) +``` +conda activate DiffSynthStudio ``` +## Usage (in WebUI) + +``` +python -m streamlit run Diffsynth_Studio.py +``` + +## Usage (in Python code) + +### Example 1: Stable Diffusion + +We can generate images with very high resolution. Please see `examples/sd_text_to_image.py` for more details. + +|512*512|1024*1024|2048*2048|4096*4096| +|-|-|-|-| +|![512](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/55f679e9-7445-4605-9315-302e93d11370)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/9087a73c-9164-4c58-b2a0-effc694143fb)|![4096](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edee9e71-fc39-4d1c-9ca9-fa52002c67ac)| + +### Example 2: Stable Diffusion XL + +Generate images with Stable Diffusion XL. Please see `examples/sdxl_text_to_image.py` for more details. + +|1024*1024|2048*2048| +|-|-| +|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)| + +### Example 3: Stable Diffusion XL Turbo + +Generate images with Stable Diffusion XL Turbo. 
You can see `examples/sdxl_turbo.py` for more details, but we highly recommend you to use it in the WebUI. + +|"black car"|"red car"| +|-|-| +|![black_car](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/7fbfd803-68d4-44f3-8713-8c925fec47d0)|![black_car_to_red_car](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/aaf886e4-c33c-4fd8-98e2-29eef117ba00)| + +### Example 4: Toon Shading + +A very interesting example. Please see `examples/sd_toon_shading.py` for more details. + +https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/53532f0e-39b1-4791-b920-c975d52ec24a diff --git a/diffsynth/controlnets/controlnet_unit.py b/diffsynth/controlnets/controlnet_unit.py index d279bd7..3a0ad97 100644 --- a/diffsynth/controlnets/controlnet_unit.py +++ b/diffsynth/controlnets/controlnet_unit.py @@ -36,10 +36,17 @@ class MultiControlNetManager: ], dim=0) return processed_image - def __call__(self, sample, timestep, encoder_hidden_states, conditionings): + def __call__( + self, + sample, timestep, encoder_hidden_states, conditionings, + tiled=False, tile_size=64, tile_stride=32 + ): res_stack = None for conditioning, model, scale in zip(conditionings, self.models, self.scales): - res_stack_ = model(sample, timestep, encoder_hidden_states, conditioning) + res_stack_ = model( + sample, timestep, encoder_hidden_states, conditioning, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride + ) res_stack_ = [res * scale for res in res_stack_] if res_stack is None: res_stack = res_stack_ diff --git a/diffsynth/controlnets/processors.py b/diffsynth/controlnets/processors.py index 36bceb3..d6ea121 100644 --- a/diffsynth/controlnets/processors.py +++ b/diffsynth/controlnets/processors.py @@ -12,7 +12,7 @@ Processor_id: TypeAlias = Literal[ ] class Annotator: - def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=512): + def __init__(self, processor_id: Processor_id, model_path="models/Annotators", 
detect_resolution=None): if processor_id == "canny": self.processor = CannyDetector() elif processor_id == "depth": @@ -44,7 +44,8 @@ class Annotator: else: kwargs = {} if self.processor is not None: - image = self.processor(image, detect_resolution=self.detect_resolution, image_resolution=min(width, height), **kwargs) + detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height) + image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs) image = image.resize((width, height)) return image diff --git a/diffsynth/models/__init__.py b/diffsynth/models/__init__.py index 41fecb5..dcbe1d7 100644 --- a/diffsynth/models/__init__.py +++ b/diffsynth/models/__init__.py @@ -15,6 +15,8 @@ from .sd_controlnet import SDControlNet from .sd_motion import SDMotionModel +from transformers import AutoModelForCausalLM + class ModelManager: def __init__(self, torch_dtype=torch.float16, device="cuda"): @@ -24,12 +26,19 @@ class ModelManager: self.model_path = {} self.textual_inversion_dict = {} + def is_beautiful_prompt(self, state_dict): + param_name = "transformer.h.9.self_attention.query_key_value.weight" + return param_name in state_dict + def is_stabe_diffusion_xl(self, state_dict): param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight" return param_name in state_dict def is_stable_diffusion(self, state_dict): - return True + if self.is_stabe_diffusion_xl(state_dict): + return False + param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight" + return param_name in state_dict def is_controlnet(self, state_dict): param_name = "control_model.time_embed.0.weight" @@ -74,7 +83,6 @@ class ModelManager: "unet": SDXLUNet, "vae_decoder": SDXLVAEDecoder, "vae_encoder": SDXLVAEEncoder, - "refiner": SDXLUNet, } if components is None: components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"] @@ 
-109,6 +117,15 @@ class ModelManager: self.model[component] = model self.model_path[component] = file_path + def load_beautiful_prompt(self, state_dict, file_path=""): + component = "beautiful_prompt" + model_folder = os.path.dirname(file_path) + model = AutoModelForCausalLM.from_pretrained( + model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype + ).to(self.device).eval() + self.model[component] = model + self.model_path[component] = file_path + def search_for_embeddings(self, state_dict): embeddings = [] for k in state_dict: @@ -144,6 +161,8 @@ class ModelManager: self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path) elif self.is_stable_diffusion(state_dict): self.load_stable_diffusion(state_dict, components=components, file_path=file_path) + elif self.is_beautiful_prompt(state_dict): + self.load_beautiful_prompt(state_dict, file_path=file_path) def load_models(self, file_path_list): for file_path in file_path_list: diff --git a/diffsynth/models/attention.py b/diffsynth/models/attention.py index 1a2c110..5961c11 100644 --- a/diffsynth/models/attention.py +++ b/diffsynth/models/attention.py @@ -41,7 +41,7 @@ class Attention(torch.nn.Module): v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) - hidden_states = hidden_states.transpose(1, 2).view(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.to(q.dtype) hidden_states = self.to_out(hidden_states) diff --git a/diffsynth/models/sd_controlnet.py b/diffsynth/models/sd_controlnet.py index f43a6de..2b6f57e 100644 --- a/diffsynth/models/sd_controlnet.py +++ b/diffsynth/models/sd_controlnet.py @@ -1,5 +1,6 @@ import torch -from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock, 
DownSampler, UpSampler +from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler +from .tiler import TileWorker class ControlNetConditioningLayer(torch.nn.Module): @@ -92,20 +93,37 @@ class SDControlNet(torch.nn.Module): self.global_pool = global_pool - def forward(self, sample, timestep, encoder_hidden_states, conditioning): + def forward( + self, + sample, timestep, encoder_hidden_states, conditioning, + tiled=False, tile_size=64, tile_stride=32, + ): # 1. time time_emb = self.time_proj(timestep[None]).to(sample.dtype) time_emb = self.time_embedding(time_emb) time_emb = time_emb.repeat(sample.shape[0], 1) # 2. pre-process + height, width = sample.shape[2], sample.shape[3] hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning) text_emb = encoder_hidden_states res_stack = [hidden_states] # 3. blocks for i, block in enumerate(self.blocks): - hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + if tiled and not isinstance(block, PushBlock): + _, _, inter_height, _ = hidden_states.shape + resize_scale = inter_height / height + hidden_states = TileWorker().tiled_forward( + lambda x: block(x, time_emb, text_emb, res_stack)[0], + hidden_states, + int(tile_size * resize_scale), + int(tile_stride * resize_scale), + tile_device=hidden_states.device, + tile_dtype=hidden_states.dtype + ) + else: + hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack) # 4. 
ControlNet blocks controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)] diff --git a/diffsynth/models/sd_unet.py b/diffsynth/models/sd_unet.py index 1fd1f02..dcdcb53 100644 --- a/diffsynth/models/sd_unet.py +++ b/diffsynth/models/sd_unet.py @@ -279,31 +279,19 @@ class SDUNet(torch.nn.Module): self.conv_act = torch.nn.SiLU() self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1) - def forward(self, sample, timestep, encoder_hidden_states, tiled=False, tile_size=64, tile_stride=8, additional_res_stack=None, **kwargs): + def forward(self, sample, timestep, encoder_hidden_states, **kwargs): # 1. time time_emb = self.time_proj(timestep[None]).to(sample.dtype) time_emb = self.time_embedding(time_emb) - time_emb = time_emb.repeat(sample.shape[0], 1) # 2. pre-process - height, width = sample.shape[2], sample.shape[3] hidden_states = self.conv_in(sample) text_emb = encoder_hidden_states res_stack = [hidden_states] # 3. blocks for i, block in enumerate(self.blocks): - if additional_res_stack is not None and i==31: - hidden_states += additional_res_stack.pop() - res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)] - additional_res_stack = None - if tiled: - hidden_states, time_emb, text_emb, res_stack = self.tiled_inference( - block, hidden_states, time_emb, text_emb, res_stack, - height, width, tile_size, tile_stride - ) - else: - hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) # 4. 
output hidden_states = self.conv_norm_out(hidden_states) @@ -312,23 +300,6 @@ class SDUNet(torch.nn.Module): return hidden_states - def tiled_inference(self, block, hidden_states, time_emb, text_emb, res_stack, height, width, tile_size, tile_stride): - if block.__class__.__name__ in ["ResnetBlock", "AttentionBlock", "DownSampler", "UpSampler"]: - batch_size, inter_channel, inter_height, inter_width = hidden_states.shape - resize_scale = inter_height / height - - hidden_states = Tiler()( - lambda x: block(x, time_emb, text_emb, res_stack)[0], - hidden_states, - int(tile_size * resize_scale), - int(tile_stride * resize_scale), - inter_device=hidden_states.device, - inter_dtype=hidden_states.dtype - ) - else: - hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) - return hidden_states, time_emb, text_emb, res_stack - def state_dict_converter(self): return SDUNetStateDictConverter() diff --git a/diffsynth/models/sd_vae_decoder.py b/diffsynth/models/sd_vae_decoder.py index e62270c..0fba92d 100644 --- a/diffsynth/models/sd_vae_decoder.py +++ b/diffsynth/models/sd_vae_decoder.py @@ -1,7 +1,7 @@ import torch from .attention import Attention from .sd_unet import ResnetBlock, UpSampler -from .tiler import Tiler +from .tiler import TileWorker class VAEAttentionBlock(torch.nn.Module): @@ -79,11 +79,13 @@ class SDVAEDecoder(torch.nn.Module): self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1) def tiled_forward(self, sample, tile_size=64, tile_stride=32): - hidden_states = Tiler()( + hidden_states = TileWorker().tiled_forward( lambda x: self.forward(x), sample, tile_size, - tile_stride + tile_stride, + tile_device=sample.device, + tile_dtype=sample.dtype ) return hidden_states diff --git a/diffsynth/models/sd_vae_encoder.py b/diffsynth/models/sd_vae_encoder.py index 4968e6c..7e284be 100644 --- a/diffsynth/models/sd_vae_encoder.py +++ b/diffsynth/models/sd_vae_encoder.py @@ -1,7 +1,7 @@ import torch from .sd_unet import 
ResnetBlock, DownSampler from .sd_vae_decoder import VAEAttentionBlock -from .tiler import Tiler +from .tiler import TileWorker class SDVAEEncoder(torch.nn.Module): @@ -38,11 +38,13 @@ class SDVAEEncoder(torch.nn.Module): self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1) def tiled_forward(self, sample, tile_size=64, tile_stride=32): - hidden_states = Tiler()( + hidden_states = TileWorker().tiled_forward( lambda x: self.forward(x), sample, tile_size, - tile_stride + tile_stride, + tile_device=sample.device, + tile_dtype=sample.dtype ) return hidden_states diff --git a/diffsynth/models/tiler.py b/diffsynth/models/tiler.py index 906b4f6..30db58f 100644 --- a/diffsynth/models/tiler.py +++ b/diffsynth/models/tiler.py @@ -1,4 +1,5 @@ import torch +from einops import rearrange, repeat class Tiler(torch.nn.Module): @@ -70,6 +71,106 @@ class Tiler(torch.nn.Module): return x - - \ No newline at end of file +class TileWorker: + def __init__(self): + pass + + + def mask(self, height, width, border_width): + # Create a mask with shape (height, width). + # The centre area is filled with 1, and the border line is filled with values in range (0, 1]. 
+ x = torch.arange(height).repeat(width, 1).T + y = torch.arange(width).repeat(height, 1) + mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values + mask = (mask / border_width).clip(0, 1) + return mask + + + def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype): + # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num) + batch_size, channel, _, _ = model_input.shape + model_input = model_input.to(device=tile_device, dtype=tile_dtype) + unfold_operator = torch.nn.Unfold( + kernel_size=(tile_size, tile_size), + stride=(tile_stride, tile_stride) + ) + model_input = unfold_operator(model_input) + model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1)) + + return model_input + + + def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype): + # Call y=forward_fn(x) for each tile + tile_num = model_input.shape[-1] + model_output_stack = [] + + for tile_id in range(0, tile_num, tile_batch_size): + + # process input + tile_id_ = min(tile_id + tile_batch_size, tile_num) + x = model_input[:, :, :, :, tile_id: tile_id_] + x = x.to(device=inference_device, dtype=inference_dtype) + x = rearrange(x, "b c h w n -> (n b) c h w") + + # process output + y = forward_fn(x) + y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id) + y = y.to(device=tile_device, dtype=tile_dtype) + model_output_stack.append(y) + + model_output = torch.concat(model_output_stack, dim=-1) + return model_output + + + def io_scale(self, model_output, tile_size): + # Determine the size modification happend in forward_fn + # We only consider the same scale on height and width. 
+ io_scale = model_output.shape[2] / tile_size + return io_scale + + + def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype): + # The reversed function of tile + mask = self.mask(tile_size, tile_size, border_width) + mask = mask.to(device=tile_device, dtype=tile_dtype) + mask = rearrange(mask, "h w -> 1 1 h w 1") + model_output = model_output * mask + + fold_operator = torch.nn.Fold( + output_size=(height, width), + kernel_size=(tile_size, tile_size), + stride=(tile_stride, tile_stride) + ) + mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1]) + model_output = rearrange(model_output, "b c h w n -> b (c h w) n") + model_output = fold_operator(model_output) / fold_operator(mask) + + return model_output + + + def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None): + # Prepare + inference_device, inference_dtype = model_input.device, model_input.dtype + height, width = model_input.shape[2], model_input.shape[3] + border_width = int(tile_stride*0.5) if border_width is None else border_width + + # tile + model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype) + + # inference + model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype) + + # resize + io_scale = self.io_scale(model_output, tile_size) + height, width = int(height*io_scale), int(width*io_scale) + tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale) + border_width = int(border_width*io_scale) + + # untile + model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype) + + # Done! 
+ model_output = model_output.to(device=inference_device, dtype=inference_dtype) + return model_output \ No newline at end of file diff --git a/diffsynth/pipelines/__init__.py b/diffsynth/pipelines/__init__.py index 9fe28d1..5074e75 100644 --- a/diffsynth/pipelines/__init__.py +++ b/diffsynth/pipelines/__init__.py @@ -1,3 +1,3 @@ -from .stable_diffusion import SDPipeline -from .stable_diffusion_xl import SDXLPipeline +from .stable_diffusion import SDImagePipeline +from .stable_diffusion_xl import SDXLImagePipeline from .stable_diffusion_video import SDVideoPipeline diff --git a/diffsynth/pipelines/dancer.py b/diffsynth/pipelines/dancer.py new file mode 100644 index 0000000..800876b --- /dev/null +++ b/diffsynth/pipelines/dancer.py @@ -0,0 +1,113 @@ +import torch +from ..models import SDUNet, SDMotionModel +from ..models.sd_unet import PushBlock, PopBlock +from ..models.tiler import TileWorker +from ..controlnets import MultiControlNetManager + + +def lets_dance( + unet: SDUNet, + motion_modules: SDMotionModel = None, + controlnet: MultiControlNetManager = None, + sample = None, + timestep = None, + encoder_hidden_states = None, + controlnet_frames = None, + unet_batch_size = 1, + controlnet_batch_size = 1, + tiled=False, + tile_size=64, + tile_stride=32, + device = "cuda", + vram_limit_level = 0, +): + # 1. ControlNet + # This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride. + # I leave it here because I intend to do something interesting on the ControlNets. 
+ controlnet_insert_block_id = 30 + if controlnet is not None and controlnet_frames is not None: + res_stacks = [] + # process controlnet frames with batch + for batch_id in range(0, sample.shape[0], controlnet_batch_size): + batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0]) + res_stack = controlnet( + sample[batch_id: batch_id_], + timestep, + encoder_hidden_states[batch_id: batch_id_], + controlnet_frames[:, batch_id: batch_id_], + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride + ) + if vram_limit_level >= 1: + res_stack = [res.cpu() for res in res_stack] + res_stacks.append(res_stack) + # concat the residual + additional_res_stack = [] + for i in range(len(res_stacks[0])): + res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0) + additional_res_stack.append(res) + else: + additional_res_stack = None + + # 2. time + time_emb = unet.time_proj(timestep[None]).to(sample.dtype) + time_emb = unet.time_embedding(time_emb) + + # 3. pre-process + height, width = sample.shape[2], sample.shape[3] + hidden_states = unet.conv_in(sample) + text_emb = encoder_hidden_states + res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states] + + # 4. 
blocks + for block_id, block in enumerate(unet.blocks): + # 4.1 UNet + if isinstance(block, PushBlock): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + if vram_limit_level>=1: + res_stack[-1] = res_stack[-1].cpu() + elif isinstance(block, PopBlock): + if vram_limit_level>=1: + res_stack[-1] = res_stack[-1].to(device) + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + else: + hidden_states_input = hidden_states + hidden_states_output = [] + for batch_id in range(0, sample.shape[0], unet_batch_size): + batch_id_ = min(batch_id + unet_batch_size, sample.shape[0]) + if tiled: + _, _, inter_height, _ = hidden_states.shape + resize_scale = inter_height / height + hidden_states = TileWorker().tiled_forward( + lambda x: block(x, time_emb, text_emb[batch_id: batch_id_], res_stack)[0], + hidden_states_input[batch_id: batch_id_], + int(tile_size * resize_scale), + int(tile_stride * resize_scale), + tile_device=hidden_states.device, + tile_dtype=hidden_states.dtype + ) + else: + hidden_states, _, _, _ = block(hidden_states_input[batch_id: batch_id_], time_emb, text_emb[batch_id: batch_id_], res_stack) + hidden_states_output.append(hidden_states) + hidden_states = torch.concat(hidden_states_output, dim=0) + # 4.2 AnimateDiff + if motion_modules is not None: + if block_id in motion_modules.call_block_id: + motion_module_id = motion_modules.call_block_id[block_id] + hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id]( + hidden_states, time_emb, text_emb, res_stack, + batch_size=1 + ) + # 4.3 ControlNet + if block_id == controlnet_insert_block_id and additional_res_stack is not None: + hidden_states += additional_res_stack.pop().to(device) + if vram_limit_level>=1: + res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)] + else: + res_stack = [res + 
additional_res for res, additional_res in zip(res_stack, additional_res_stack)] + + # 5. output + hidden_states = unet.conv_norm_out(hidden_states) + hidden_states = unet.conv_act(hidden_states) + hidden_states = unet.conv_out(hidden_states) + + return hidden_states diff --git a/diffsynth/pipelines/stable_diffusion.py b/diffsynth/pipelines/stable_diffusion.py index 9d5d962..d82d480 100644 --- a/diffsynth/pipelines/stable_diffusion.py +++ b/diffsynth/pipelines/stable_diffusion.py @@ -1,14 +1,16 @@ from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder -from ..controlnets.controlnet_unit import MultiControlNetManager +from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator from ..prompts import SDPrompter from ..schedulers import EnhancedDDIMScheduler +from .dancer import lets_dance +from typing import List import torch from tqdm import tqdm from PIL import Image import numpy as np -class SDPipeline(torch.nn.Module): +class SDImagePipeline(torch.nn.Module): def __init__(self, device="cuda", torch_dtype=torch.float16): super().__init__() @@ -23,6 +25,7 @@ class SDPipeline(torch.nn.Module): self.vae_encoder: SDVAEEncoder = None self.controlnet: MultiControlNetManager = None + def fetch_main_models(self, model_manager: ModelManager): self.text_encoder = model_manager.text_encoder self.unet = model_manager.unet @@ -31,13 +34,48 @@ class SDPipeline(torch.nn.Module): # load textual inversion self.prompter.load_textual_inversion(model_manager.textual_inversion_dict) - def fetch_controlnet_models(self, controlnet_units=[]): + + def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]): + controlnet_units = [] + for config in controlnet_config_units: + controlnet_unit = ControlNetUnit( + Annotator(config.processor_id), + model_manager.get_model_with_model_path(config.model_path), + config.scale + ) + controlnet_units.append(controlnet_unit) 
self.controlnet = MultiControlNetManager(controlnet_units) + + + def fetch_beautiful_prompt(self, model_manager: ModelManager): + if "beautiful_prompt" in model_manager.model: + self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"]) + + + @staticmethod + def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]): + pipe = SDImagePipeline( + device=model_manager.device, + torch_dtype=model_manager.torch_dtype, + ) + pipe.fetch_main_models(model_manager) + pipe.fetch_beautiful_prompt(model_manager) + pipe.fetch_controlnet_models(model_manager, controlnet_config_units) + return pipe + def preprocess_image(self, image): image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) return image + + def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): + image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + image = image.cpu().permute(1, 2, 0).numpy() + image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) + return image + + @torch.no_grad() def __call__( self, @@ -45,7 +83,7 @@ class SDPipeline(torch.nn.Module): negative_prompt="", cfg_scale=7.5, clip_skip=1, - init_image=None, + input_image=None, controlnet_image=None, denoising_strength=1.0, height=512, @@ -57,48 +95,43 @@ class SDPipeline(torch.nn.Module): progress_bar_cmd=tqdm, progress_bar_st=None, ): - # Encode prompts - prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device) - prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device) - # Prepare scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength) # Prepare latent tensors - if init_image is not None: - image = 
self.preprocess_image(init_image).to(device=self.device, dtype=self.torch_dtype) + if input_image is not None: + image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) else: latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) + # Encode prompts + prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True) + prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False) + # Prepare ControlNets if controlnet_image is not None: controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype) + controlnet_image = controlnet_image.unsqueeze(1) # Denoise for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = torch.IntTensor((timestep,))[0].to(self.device) - # ControlNet - if controlnet_image is not None: - additional_res_stack_posi = self.controlnet(latents, timestep, prompt_emb_posi, controlnet_image) - additional_res_stack_nega = self.controlnet(latents, timestep, prompt_emb_nega, controlnet_image) - else: - additional_res_stack_posi = None - additional_res_stack_nega = None - # Classifier-free guidance - noise_pred_posi = self.unet( - latents, timestep, prompt_emb_posi, - additional_res_stack=additional_res_stack_posi, - tiled=tiled, tile_size=tile_size, tile_stride=tile_stride + noise_pred_posi = lets_dance( + self.unet, motion_modules=None, controlnet=self.controlnet, + sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image, + 
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, + device=self.device, vram_limit_level=0 ) - noise_pred_nega = self.unet( - latents, timestep, prompt_emb_nega, - additional_res_stack=additional_res_stack_nega, - tiled=tiled, tile_size=tile_size, tile_stride=tile_stride + noise_pred_nega = lets_dance( + self.unet, motion_modules=None, controlnet=self.controlnet, + sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, + device=self.device, vram_limit_level=0 ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) @@ -110,8 +143,6 @@ class SDPipeline(torch.nn.Module): progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) # Decode image - image = self.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] - image = image.cpu().permute(1, 2, 0).numpy() - image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) + image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) return image diff --git a/diffsynth/pipelines/stable_diffusion_video.py b/diffsynth/pipelines/stable_diffusion_video.py index 64e4783..1aa2f3e 100644 --- a/diffsynth/pipelines/stable_diffusion_video.py +++ b/diffsynth/pipelines/stable_diffusion_video.py @@ -1,8 +1,8 @@ from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDMotionModel -from ..models.sd_unet import PushBlock, PopBlock from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator from ..prompts import SDPrompter from ..schedulers import EnhancedDDIMScheduler +from .dancer import lets_dance from typing import List import torch from tqdm import tqdm @@ -10,97 +10,6 @@ from PIL import Image import numpy as np -def lets_dance( - unet: SDUNet, - motion_modules: SDMotionModel = None, - controlnet: MultiControlNetManager = None, - 
sample = None, - timestep = None, - encoder_hidden_states = None, - controlnet_frames = None, - unet_batch_size = 1, - controlnet_batch_size = 1, - device = "cuda", - vram_limit_level = 0, -): - # 1. ControlNet - # This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride. - # I leave it here because I intend to do something interesting on the ControlNets. - controlnet_insert_block_id = 30 - if controlnet is not None and controlnet_frames is not None: - res_stacks = [] - # process controlnet frames with batch - for batch_id in range(0, sample.shape[0], controlnet_batch_size): - batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0]) - res_stack = controlnet( - sample[batch_id: batch_id_], - timestep, - encoder_hidden_states[batch_id: batch_id_], - controlnet_frames[:, batch_id: batch_id_] - ) - if vram_limit_level >= 1: - res_stack = [res.cpu() for res in res_stack] - res_stacks.append(res_stack) - # concat the residual - additional_res_stack = [] - for i in range(len(res_stacks[0])): - res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0) - additional_res_stack.append(res) - else: - additional_res_stack = None - - # 2. time - time_emb = unet.time_proj(timestep[None]).to(sample.dtype) - time_emb = unet.time_embedding(time_emb) - - # 3. pre-process - hidden_states = unet.conv_in(sample) - text_emb = encoder_hidden_states - res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states] - - # 4. 
blocks - for block_id, block in enumerate(unet.blocks): - # 4.1 UNet - if isinstance(block, PushBlock): - hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) - if vram_limit_level>=1: - res_stack[-1] = res_stack[-1].cpu() - elif isinstance(block, PopBlock): - if vram_limit_level>=1: - res_stack[-1] = res_stack[-1].to(device) - hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) - else: - hidden_states_input = hidden_states - hidden_states_output = [] - for batch_id in range(0, sample.shape[0], unet_batch_size): - batch_id_ = min(batch_id + unet_batch_size, sample.shape[0]) - hidden_states, _, _, _ = block(hidden_states_input[batch_id: batch_id_], time_emb, text_emb[batch_id: batch_id_], res_stack) - hidden_states_output.append(hidden_states) - hidden_states = torch.concat(hidden_states_output, dim=0) - # 4.2 AnimateDiff - if motion_modules is not None: - if block_id in motion_modules.call_block_id: - motion_module_id = motion_modules.call_block_id[block_id] - hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id]( - hidden_states, time_emb, text_emb, res_stack, - batch_size=1 - ) - # 4.3 ControlNet - if block_id == controlnet_insert_block_id and additional_res_stack is not None: - hidden_states += additional_res_stack.pop().to(device) - if vram_limit_level>=1: - res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)] - else: - res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)] - - # 5. 
output - hidden_states = unet.conv_norm_out(hidden_states) - hidden_states = unet.conv_act(hidden_states) - hidden_states = unet.conv_out(hidden_states) - - return hidden_states - - def lets_dance_with_long_video( unet: SDUNet, motion_modules: SDMotionModel = None, @@ -187,6 +96,11 @@ class SDVideoPipeline(torch.nn.Module): self.motion_modules = model_manager.motion_modules + def fetch_beautiful_prompt(self, model_manager: ModelManager): + if "beautiful_prompt" in model_manager.model: + self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"]) + + @staticmethod def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]): pipe = SDVideoPipeline( @@ -196,6 +110,7 @@ class SDVideoPipeline(torch.nn.Module): ) pipe.fetch_main_models(model_manager) pipe.fetch_motion_modules(model_manager) + pipe.fetch_beautiful_prompt(model_manager) pipe.fetch_controlnet_models(model_manager, controlnet_config_units) return pipe @@ -248,12 +163,6 @@ class SDVideoPipeline(torch.nn.Module): progress_bar_cmd=tqdm, progress_bar_st=None, ): - # Encode prompts - prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device).cpu() - prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device).cpu() - prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1) - prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1) - # Prepare scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength) @@ -265,6 +174,12 @@ class SDVideoPipeline(torch.nn.Module): latents = self.encode_images(input_frames) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) + # Encode prompts + prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True).cpu() + prompt_emb_nega = 
self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False).cpu() + prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1) + prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1) + # Prepare ControlNets if controlnet_frames is not None: controlnet_frames = torch.stack([ diff --git a/diffsynth/pipelines/stable_diffusion_xl.py b/diffsynth/pipelines/stable_diffusion_xl.py index c002289..f9dd481 100644 --- a/diffsynth/pipelines/stable_diffusion_xl.py +++ b/diffsynth/pipelines/stable_diffusion_xl.py @@ -1,4 +1,5 @@ -from ..models import ModelManager +from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder +# TODO: SDXL ControlNet from ..prompts import SDXLPrompter from ..schedulers import EnhancedDDIMScheduler import torch @@ -7,29 +8,77 @@ from PIL import Image import numpy as np -class SDXLPipeline(torch.nn.Module): +class SDXLImagePipeline(torch.nn.Module): - def __init__(self): + def __init__(self, device="cuda", torch_dtype=torch.float16): super().__init__() self.scheduler = EnhancedDDIMScheduler() + self.prompter = SDXLPrompter() + self.device = device + self.torch_dtype = torch_dtype + # models + self.text_encoder: SDXLTextEncoder = None + self.text_encoder_2: SDXLTextEncoder2 = None + self.unet: SDXLUNet = None + self.vae_decoder: SDXLVAEDecoder = None + self.vae_encoder: SDXLVAEEncoder = None + # TODO: SDXL ControlNet + def fetch_main_models(self, model_manager: ModelManager): + self.text_encoder = model_manager.text_encoder + self.text_encoder_2 = model_manager.text_encoder_2 + self.unet = model_manager.unet + self.vae_decoder = model_manager.vae_decoder + self.vae_encoder = model_manager.vae_encoder + # load textual inversion + self.prompter.load_textual_inversion(model_manager.textual_inversion_dict) + + + def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs): + # TODO: SDXL ControlNet + pass + + + def 
fetch_beautiful_prompt(self, model_manager: ModelManager): + if "beautiful_prompt" in model_manager.model: + self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"]) + + + @staticmethod + def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs): + pipe = SDXLImagePipeline( + device=model_manager.device, + torch_dtype=model_manager.torch_dtype, + ) + pipe.fetch_main_models(model_manager) + pipe.fetch_beautiful_prompt(model_manager) + pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units) + return pipe + + def preprocess_image(self, image): image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) return image + + def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): + image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + image = image.cpu().permute(1, 2, 0).numpy() + image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) + return image + + @torch.no_grad() def __call__( self, - model_manager: ModelManager, - prompter: SDXLPrompter, prompt, negative_prompt="", cfg_scale=7.5, clip_skip=1, clip_skip_2=2, - init_image=None, + input_image=None, + controlnet_image=None, denoising_strength=1.0, - refining_strength=0.0, height=1024, width=1024, num_inference_steps=20, @@ -39,76 +88,62 @@ class SDXLPipeline(torch.nn.Module): progress_bar_cmd=tqdm, progress_bar_st=None, ): - # Encode prompts - add_text_embeds, prompt_emb = prompter.encode_prompt( - model_manager.text_encoder, - model_manager.text_encoder_2, - prompt, - clip_skip=clip_skip, clip_skip_2=clip_skip_2, - device=model_manager.device - ) - if cfg_scale != 1.0: - negative_add_text_embeds, negative_prompt_emb = prompter.encode_prompt( - model_manager.text_encoder, - model_manager.text_encoder_2, - negative_prompt, - clip_skip=clip_skip, 
clip_skip_2=clip_skip_2, - device=model_manager.device - ) - # Prepare scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength) # Prepare latent tensors - if init_image is not None: - image = self.preprocess_image(init_image).to( - device=model_manager.device, dtype=model_manager.torch_type - ) - latents = model_manager.vae_encoder( - image.to(torch.float32), - tiled=tiled, tile_size=tile_size, tile_stride=tile_stride - ) - noise = torch.randn( - (1, 4, height//8, width//8), - device=model_manager.device, dtype=model_manager.torch_type - ) - latents = self.scheduler.add_noise( - latents.to(model_manager.torch_type), - noise, - timestep=self.scheduler.timesteps[0] - ) + if input_image is not None: + image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) + latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype) + noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) + latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) else: - latents = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type) + latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) + + # Encode prompts + add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt( + self.text_encoder, + self.text_encoder_2, + prompt, + clip_skip=clip_skip, clip_skip_2=clip_skip_2, + device=self.device + ) + if cfg_scale != 1.0: + add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt( + self.text_encoder, + self.text_encoder_2, + negative_prompt, + clip_skip=clip_skip, clip_skip_2=clip_skip_2, + device=self.device + ) + + # Prepare scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength) # Prepare positional id - add_time_id = torch.tensor([height, width, 0, 0, height, width], 
device=model_manager.device) + add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device) # Denoise for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): - timestep = torch.IntTensor((timestep,))[0].to(model_manager.device) + timestep = torch.IntTensor((timestep,))[0].to(self.device) # Classifier-free guidance - if timestep >= 1000 * refining_strength: - denoising_model = model_manager.unet - else: - denoising_model = model_manager.refiner - if cfg_scale != 1.0: - noise_pred_cond = denoising_model( - latents, timestep, prompt_emb, - add_time_id=add_time_id, add_text_embeds=add_text_embeds, + noise_pred_posi = self.unet( + latents, timestep, prompt_emb_posi, + add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride ) - noise_pred_uncond = denoising_model( - latents, timestep, negative_prompt_emb, - add_time_id=add_time_id, add_text_embeds=negative_add_text_embeds, + noise_pred_nega = self.unet( + latents, timestep, prompt_emb_nega, + add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride ) - noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: - noise_pred = denoising_model( - latents, timestep, prompt_emb, - add_time_id=add_time_id, add_text_embeds=add_text_embeds, + noise_pred = self.unet( + latents, timestep, prompt_emb_posi, + add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride ) @@ -118,9 +153,6 @@ class SDXLPipeline(torch.nn.Module): progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) # Decode image - latents = latents.to(torch.float32) - image = model_manager.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] - image = image.cpu().permute(1, 2, 
0).numpy() - image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) + image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) return image diff --git a/diffsynth/prompts/__init__.py b/diffsynth/prompts/__init__.py index af4e2d0..be94f8b 100644 --- a/diffsynth/prompts/__init__.py +++ b/diffsynth/prompts/__init__.py @@ -1,5 +1,5 @@ -from transformers import CLIPTokenizer -from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2, load_state_dict +from transformers import CLIPTokenizer, AutoTokenizer +from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2 import torch, os @@ -35,14 +35,30 @@ def tokenize_long_prompt(tokenizer, prompt): return input_ids -def search_for_embeddings(state_dict): - embeddings = [] - for k in state_dict: - if isinstance(state_dict[k], torch.Tensor): - embeddings.append(state_dict[k]) - elif isinstance(state_dict[k], dict): - embeddings += search_for_embeddings(state_dict[k]) - return embeddings +class BeautifulPrompt: + def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + self.model = model + self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:' + + def __call__(self, raw_prompt): + model_input = self.template.format(raw_prompt=raw_prompt) + input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device) + outputs = self.model.generate( + input_ids, + max_new_tokens=384, + do_sample=True, + temperature=0.9, + top_k=50, + top_p=0.95, + repetition_penalty=1.1, + num_return_sequences=1 + ) + prompt = raw_prompt + ", " + self.tokenizer.batch_decode( + outputs[:, input_ids.size(1):], + skip_special_tokens=True + )[0].strip() + return prompt class SDPrompter: @@ -50,11 +66,20 @@ class SDPrompter: # We use the tokenizer implemented by 
transformers self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) self.keyword_dict = {} + self.beautiful_prompt: BeautifulPrompt = None - def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda"): + + def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True): + # Textual Inversion for keyword in self.keyword_dict: if keyword in prompt: prompt = prompt.replace(keyword, self.keyword_dict[keyword]) + + # Beautiful Prompt + if positive and self.beautiful_prompt is not None: + prompt = self.beautiful_prompt(prompt) + print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"") + input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device) prompt_emb = text_encoder(input_ids, clip_skip=clip_skip) prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1)) @@ -70,6 +95,16 @@ class SDPrompter: self.keyword_dict[keyword] = " " + " ".join(tokens) + " " self.tokenizer.add_tokens(additional_tokens) + def load_beautiful_prompt(self, model, model_path): + model_folder = os.path.dirname(model_path) + self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model) + if model_folder.endswith("v2"): + self.beautiful_prompt.template = """Converts a simple image description into a prompt. \ +Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \ +or use a number to specify the weight. 
You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \ +but make sure there is a correlation between the input and output.\n\ +### Input: {raw_prompt}\n### Output:""" + class SDXLPrompter: def __init__( @@ -80,6 +115,8 @@ class SDXLPrompter: # We use the tokenizer implemented by transformers self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path) + self.keyword_dict = {} + self.beautiful_prompt: BeautifulPrompt = None def encode_prompt( self, @@ -88,8 +125,19 @@ class SDXLPrompter: prompt, clip_skip=1, clip_skip_2=2, + positive=True, device="cuda" ): + # Textual Inversion + for keyword in self.keyword_dict: + if keyword in prompt: + prompt = prompt.replace(keyword, self.keyword_dict[keyword]) + + # Beautiful Prompt + if positive and self.beautiful_prompt is not None: + prompt = self.beautiful_prompt(prompt) + print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"") + # 1 input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device) prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip) @@ -105,3 +153,22 @@ class SDXLPrompter: add_text_embeds = add_text_embeds[0:1] prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1)) return add_text_embeds, prompt_emb + + def load_textual_inversion(self, textual_inversion_dict): + self.keyword_dict = {} + additional_tokens = [] + for keyword in textual_inversion_dict: + tokens, _ = textual_inversion_dict[keyword] + additional_tokens += tokens + self.keyword_dict[keyword] = " " + " ".join(tokens) + " " + self.tokenizer.add_tokens(additional_tokens) + + def load_beautiful_prompt(self, model, model_path): + model_folder = os.path.dirname(model_path) + self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model) + if model_folder.endswith("v2"): + self.beautiful_prompt.template = """Converts a simple image description into a prompt. 
\ +Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \ +or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \ +but make sure there is a correlation between the input and output.\n\ +### Input: {raw_prompt}\n### Output:""" diff --git a/diffsynth/prompts/sd_tokenizer.py b/diffsynth/prompts/sd_tokenizer.py deleted file mode 100644 index 1485138..0000000 --- a/diffsynth/prompts/sd_tokenizer.py +++ /dev/null @@ -1,37 +0,0 @@ -from transformers import CLIPTokenizer - -class SDTokenizer: - def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"): - # We use the tokenizer implemented by transformers - self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) - - def __call__(self, prompt): - # Get model_max_length from self.tokenizer - length = self.tokenizer.model_max_length - - # To avoid the warning. set self.tokenizer.model_max_length to +oo. - self.tokenizer.model_max_length = 99999999 - - # Tokenize it! - input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids - - # Determine the real length. - max_length = (input_ids.shape[1] + length - 1) // length * length - - # Restore self.tokenizer.model_max_length - self.tokenizer.model_max_length = length - - # Tokenize it again with fixed length. - input_ids = self.tokenizer( - prompt, - return_tensors="pt", - padding="max_length", - max_length=max_length, - truncation=True - ).input_ids - - # Reshape input_ids to fit the text encoder. 
- num_sentence = input_ids.shape[1] // length - input_ids = input_ids.reshape((num_sentence, length)) - - return input_ids diff --git a/diffsynth/prompts/sdxl_tokenizer.py b/diffsynth/prompts/sdxl_tokenizer.py deleted file mode 100644 index 797924d..0000000 --- a/diffsynth/prompts/sdxl_tokenizer.py +++ /dev/null @@ -1,45 +0,0 @@ -from transformers import CLIPTokenizer -from .sd_tokenizer import SDTokenizer - - -class SDXLTokenizer(SDTokenizer): - def __init__(self): - super().__init__() - - -class SDXLTokenizer2: - def __init__(self, tokenizer_path="configs/stable_diffusion_xl/tokenizer_2"): - # We use the tokenizer implemented by transformers - self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) - - def __call__(self, prompt): - # Get model_max_length from self.tokenizer - length = self.tokenizer.model_max_length - - # To avoid the warning. set self.tokenizer.model_max_length to +oo. - self.tokenizer.model_max_length = 99999999 - - # Tokenize it! - input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids - - # Determine the real length. - max_length = (input_ids.shape[1] + length - 1) // length * length - - # Restore self.tokenizer.model_max_length - self.tokenizer.model_max_length = length - - # Tokenize it again with fixed length. - input_ids = self.tokenizer( - prompt, - return_tensors="pt", - padding="max_length", - max_length=max_length, - truncation=True - ).input_ids - - # Reshape input_ids to fit the text encoder. 
- num_sentence = input_ids.shape[1] // length - input_ids = input_ids.reshape((num_sentence, length)) - - return input_ids - diff --git a/diffsynth/schedulers/__init__.py b/diffsynth/schedulers/__init__.py index a9336ff..303fffe 100644 --- a/diffsynth/schedulers/__init__.py +++ b/diffsynth/schedulers/__init__.py @@ -18,7 +18,7 @@ class EnhancedDDIMScheduler(): def set_timesteps(self, num_inference_steps, denoising_strength=1.0): # The timesteps are aligned to 999...0, which is different from other implementations, # but I think this implementation is more reasonable in theory. - max_timestep = round(self.num_train_timesteps * denoising_strength) - 1 + max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0) num_inference_steps = min(num_inference_steps, max_timestep + 1) if num_inference_steps == 1: self.timesteps = [max_timestep] diff --git a/examples/sd_text_to_image.py b/examples/sd_text_to_image.py new file mode 100644 index 0000000..76e65d3 --- /dev/null +++ b/examples/sd_text_to_image.py @@ -0,0 +1,75 @@ +from diffsynth import ModelManager, SDImagePipeline, ControlNetConfigUnit +import torch + + +# Download models +# `models/stable_diffusion/aingdiffusion_v12.safetensors`: [link](https://civitai.com/api/download/models/229575?type=Model&format=SafeTensor&size=full&fp=fp16) +# `models/ControlNet/control_v11p_sd15_lineart.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_lineart.pth) +# `models/ControlNet/control_v11f1e_sd15_tile.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1e_sd15_tile.pth) +# `models/Annotators/sk_model.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model.pth) +# `models/Annotators/sk_model2.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model2.pth) + + +# Load models +model_manager = ModelManager(torch_dtype=torch.float16, device="cuda") 
+model_manager.load_textual_inversions("models/textual_inversion") +model_manager.load_models([ + "models/stable_diffusion/aingdiffusion_v12.safetensors", + "models/ControlNet/control_v11f1e_sd15_tile.pth", + "models/ControlNet/control_v11p_sd15_lineart.pth" +]) +pipe = SDImagePipeline.from_model_manager( + model_manager, + [ + ControlNetConfigUnit( + processor_id="tile", + model_path=rf"models/ControlNet/control_v11f1e_sd15_tile.pth", + scale=0.5 + ), + ControlNetConfigUnit( + processor_id="lineart", + model_path=rf"models/ControlNet/control_v11p_sd15_lineart.pth", + scale=0.7 + ), + ] +) + +prompt = "masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait," +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,", + +torch.manual_seed(0) +image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + cfg_scale=7.5, clip_skip=1, + height=512, width=512, num_inference_steps=80, +) +image.save("512.jpg") + +image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + cfg_scale=7.5, clip_skip=1, + input_image=image.resize((1024, 1024)), controlnet_image=image.resize((1024, 1024)), + height=1024, width=1024, num_inference_steps=40, denoising_strength=0.7, +) +image.save("1024.jpg") + +image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + cfg_scale=7.5, clip_skip=1, + input_image=image.resize((2048, 2048)), controlnet_image=image.resize((2048, 2048)), + height=2048, width=2048, num_inference_steps=20, denoising_strength=0.7, +) +image.save("2048.jpg") + +image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + cfg_scale=7.5, clip_skip=1, + input_image=image.resize((4096, 4096)), controlnet_image=image.resize((4096, 4096)), + height=4096, width=4096, num_inference_steps=10, denoising_strength=0.5, + tiled=True, tile_size=128, tile_stride=64 +) 
+image.save("4096.jpg") diff --git a/examples/sd_toon_shading.py b/examples/sd_toon_shading.py new file mode 100644 index 0000000..df9f3a9 --- /dev/null +++ b/examples/sd_toon_shading.py @@ -0,0 +1,56 @@ +from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames +import torch + + +# Download models +# `models/stable_diffusion/flat2DAnimerge_v45Sharp.safetensors`: [link](https://civitai.com/api/download/models/266360?type=Model&format=SafeTensor&size=pruned&fp=fp16) +# `models/AnimateDiff/mm_sd_v15_v2.ckpt`: [link](https://huggingface.co/guoyww/animatediff/resolve/main/mm_sd_v15_v2.ckpt) +# `models/ControlNet/control_v11p_sd15_lineart.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_lineart.pth) +# `models/ControlNet/control_v11f1e_sd15_tile.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1e_sd15_tile.pth) +# `models/Annotators/sk_model.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model.pth) +# `models/Annotators/sk_model2.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model2.pth) + + +# Load models +model_manager = ModelManager(torch_dtype=torch.float16, device="cuda") +model_manager.load_textual_inversions("models/textual_inversion") +model_manager.load_models([ + "models/stable_diffusion/flat2DAnimerge_v45Sharp.safetensors", + "models/AnimateDiff/mm_sd_v15_v2.ckpt", + "models/ControlNet/control_v11p_sd15_lineart.pth", + "models/ControlNet/control_v11f1e_sd15_tile.pth", +]) +pipe = SDVideoPipeline.from_model_manager( + model_manager, + [ + ControlNetConfigUnit( + processor_id="lineart", + model_path="models/ControlNet/control_v11p_sd15_lineart.pth", + scale=1.0 + ), + ControlNetConfigUnit( + processor_id="tile", + model_path="models/ControlNet/control_v11f1e_sd15_tile.pth", + scale=0.5 + ), + ] +) + +# Load video (we only use 16 frames in this example for testing) +video = 
VideoData(video_file="input_video.mp4", height=1536, width=1536) +input_video = [video[i] for i in range(16)] + +# Toon shading +torch.manual_seed(0) +output_video = pipe( + prompt="best quality, perfect anime illustration, light, a girl is dancing, smile, solo", + negative_prompt="verybadimagenegative_v1.3", + cfg_scale=5, clip_skip=2, + controlnet_frames=input_video, num_frames=len(input_video), + num_inference_steps=10, height=1536, width=1536, + vram_limit_level=0, +) + +# Save images and video +save_frames(output_video, "output_frames") +save_video(output_video, "output_video.mp4", fps=16) diff --git a/examples/sdxl_text_to_image.py b/examples/sdxl_text_to_image.py new file mode 100644 index 0000000..16df873 --- /dev/null +++ b/examples/sdxl_text_to_image.py @@ -0,0 +1,34 @@ +from diffsynth import ModelManager, SDXLImagePipeline +import torch + + +# Download models +# `models/stable_diffusion_xl/bluePencilXL_v200.safetensors`: [link](https://civitai.com/api/download/models/245614?type=Model&format=SafeTensor&size=pruned&fp=fp16) + + +# Load models +model_manager = ModelManager(torch_dtype=torch.float16, device="cuda") +model_manager.load_models(["models/stable_diffusion_xl/bluePencilXL_v200.safetensors"]) +pipe = SDXLImagePipeline.from_model_manager(model_manager) + +prompt = "masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait," +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,", + +torch.manual_seed(0) +image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + cfg_scale=6, + height=1024, width=1024, num_inference_steps=60, +) +image.save("1024.jpg") + +image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + cfg_scale=6, + input_image=image.resize((2048, 2048)), + height=2048, width=2048, num_inference_steps=60, denoising_strength=0.5 +) 
+image.save("2048.jpg") + diff --git a/examples/sdxl_turbo.py b/examples/sdxl_turbo.py new file mode 100644 index 0000000..8d40512 --- /dev/null +++ b/examples/sdxl_turbo.py @@ -0,0 +1,31 @@ +from diffsynth import ModelManager, SDXLImagePipeline +import torch + + +# Download models +# `models/stable_diffusion_xl_turbo/sd_xl_turbo_1.0_fp16.safetensors`: [link](https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/sd_xl_turbo_1.0_fp16.safetensors) + + +# Load models +model_manager = ModelManager(torch_dtype=torch.float16, device="cuda") +model_manager.load_models(["models/stable_diffusion_xl_turbo/sd_xl_turbo_1.0_fp16.safetensors"]) +pipe = SDXLImagePipeline.from_model_manager(model_manager) + +# Text to image +torch.manual_seed(0) +image = pipe( + prompt="black car", + # Do not modify the following parameters! + cfg_scale=1, height=512, width=512, num_inference_steps=1, progress_bar_cmd=lambda x:x +) +image.save(f"black_car.jpg") + +# Image to image +torch.manual_seed(0) +image = pipe( + prompt="red car", + input_image=image, denoising_strength=0.7, + # Do not modify the following parameters! 
+ cfg_scale=1, height=512, width=512, num_inference_steps=1, progress_bar_cmd=lambda x:x +) +image.save(f"black_car_to_red_car.jpg") diff --git a/models/stable_diffusion_xl_turbo/Put Stable Diffusion XL Turbo checkpoints here.txt b/models/stable_diffusion_xl_turbo/Put Stable Diffusion XL Turbo checkpoints here.txt new file mode 100644 index 0000000..e69de29 diff --git a/pages/1_Image_Creator.py b/pages/1_Image_Creator.py new file mode 100644 index 0000000..6063b85 --- /dev/null +++ b/pages/1_Image_Creator.py @@ -0,0 +1,261 @@ +import torch, os, io +import numpy as np +from PIL import Image +import streamlit as st +st.set_page_config(layout="wide") +from streamlit_drawable_canvas import st_canvas +from diffsynth.models import ModelManager +from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline + + +config = { + "Stable Diffusion": { + "model_folder": "models/stable_diffusion", + "pipeline_class": SDImagePipeline, + "fixed_parameters": {} + }, + "Stable Diffusion XL": { + "model_folder": "models/stable_diffusion_xl", + "pipeline_class": SDXLImagePipeline, + "fixed_parameters": {} + }, + "Stable Diffusion XL Turbo": { + "model_folder": "models/stable_diffusion_xl_turbo", + "pipeline_class": SDXLImagePipeline, + "fixed_parameters": { + "negative_prompt": "", + "cfg_scale": 1.0, + "num_inference_steps": 1, + "height": 512, + "width": 512, + } + } +} + + +def load_model_list(model_type): + folder = config[model_type]["model_folder"] + file_list = os.listdir(folder) + file_list = [i for i in file_list if i.endswith(".safetensors")] + file_list = sorted(file_list) + return file_list + + +def release_model(): + if "model_manager" in st.session_state: + st.session_state["model_manager"].to("cpu") + del st.session_state["loaded_model_path"] + del st.session_state["model_manager"] + del st.session_state["pipeline"] + torch.cuda.empty_cache() + + +def load_model(model_type, model_path): + model_manager = ModelManager() + model_manager.load_model(model_path) + 
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager) + st.session_state.loaded_model_path = model_path + st.session_state.model_manager = model_manager + st.session_state.pipeline = pipeline + return model_manager, pipeline + + +def use_output_image_as_input(): + # Search for input image + output_image_id = 0 + selected_output_image = None + while True: + if f"use_output_as_input_{output_image_id}" not in st.session_state: + break + if st.session_state[f"use_output_as_input_{output_image_id}"]: + selected_output_image = st.session_state["output_images"][output_image_id] + break + output_image_id += 1 + if selected_output_image is not None: + st.session_state["input_image"] = selected_output_image + + +def apply_stroke_to_image(stroke_image, image): + image = np.array(image.convert("RGB")).astype(np.float32) + height, width, _ = image.shape + + stroke_image = np.array(Image.fromarray(stroke_image).resize((width, height))).astype(np.float32) + weight = stroke_image[:, :, -1:] / 255 + stroke_image = stroke_image[:, :, :-1] + + image = stroke_image * weight + image * (1 - weight) + image = np.clip(image, 0, 255).astype(np.uint8) + image = Image.fromarray(image) + return image + + +@st.cache_data +def image2bits(image): + image_byte = io.BytesIO() + image.save(image_byte, format="PNG") + image_byte = image_byte.getvalue() + return image_byte + + +def show_output_image(image): + st.image(image, use_column_width="always") + st.button("Use it as input image", key=f"use_output_as_input_{image_id}") + st.download_button("Download", data=image2bits(image), file_name="image.png", mime="image/png", key=f"download_output_{image_id}") + + +column_input, column_output = st.columns(2) +with st.sidebar: + # Select a model + with st.expander("Model", expanded=True): + model_type = st.selectbox("Model type", ["Stable Diffusion", "Stable Diffusion XL", "Stable Diffusion XL Turbo"]) + fixed_parameters = config[model_type]["fixed_parameters"] + 
model_path_list = ["None"] + load_model_list(model_type) + model_path = st.selectbox("Model path", model_path_list) + + # Load the model + if model_path == "None": + # No models are selected. Release VRAM. + st.markdown("No models are selected.") + release_model() + else: + # A model is selected. + model_path = os.path.join(config[model_type]["model_folder"], model_path) + if st.session_state.get("loaded_model_path", "") != model_path: + # The loaded model is not the selected model. Reload it. + st.markdown(f"Using model at {model_path}.") + release_model() + model_manager, pipeline = load_model(model_type, model_path) + else: + # The loaded model is the selected model. Fetch it from `st.session_state`. + st.markdown(f"Using model at {model_path}.") + model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline + + # Show parameters + with st.expander("Prompt", expanded=True): + prompt = st.text_area("Positive prompt") + if "negative_prompt" in fixed_parameters: + negative_prompt = fixed_parameters["negative_prompt"] + else: + negative_prompt = st.text_area("Negative prompt") + if "cfg_scale" in fixed_parameters: + cfg_scale = fixed_parameters["cfg_scale"] + else: + cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.5) + with st.expander("Image", expanded=True): + if "num_inference_steps" in fixed_parameters: + num_inference_steps = fixed_parameters["num_inference_steps"] + else: + num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20) + if "height" in fixed_parameters: + height = fixed_parameters["height"] + else: + height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512) + if "width" in fixed_parameters: + width = fixed_parameters["width"] + else: + width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512) + num_images = st.number_input("Number of images", value=2) + use_fixed_seed = st.checkbox("Use fixed 
seed", value=False) + if use_fixed_seed: + seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0) + + # Other fixed parameters + denoising_strength = 1.0 + repetition = 1 + + +# Show input image +with column_input: + with st.expander("Input image (Optional)", expanded=True): + with st.container(border=True): + column_white_board, column_upload_image = st.columns([1, 2]) + with column_white_board: + create_white_board = st.button("Create white board") + delete_input_image = st.button("Delete input image") + with column_upload_image: + upload_image = st.file_uploader("Upload image", type=["png", "jpg"], key="upload_image") + + if upload_image is not None: + st.session_state["input_image"] = Image.open(upload_image) + elif create_white_board: + st.session_state["input_image"] = Image.fromarray(np.ones((height, width, 3), dtype=np.uint8) * 255) + else: + use_output_image_as_input() + + if delete_input_image and "input_image" in st.session_state: + del st.session_state.input_image + if delete_input_image and "upload_image" in st.session_state: + del st.session_state.upload_image + + input_image = st.session_state.get("input_image", None) + if input_image is not None: + with st.container(border=True): + column_drawing_mode, column_color_1, column_color_2 = st.columns([4, 1, 1]) + with column_drawing_mode: + drawing_mode = st.radio("Drawing tool", ["transform", "freedraw", "line", "rect"], horizontal=True, index=1) + with column_color_1: + stroke_color = st.color_picker("Stroke color") + with column_color_2: + fill_color = st.color_picker("Fill color") + stroke_width = st.slider("Stroke width", min_value=1, max_value=50, value=10) + with st.container(border=True): + denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=0.7) + repetition = st.slider("Repetition", min_value=1, max_value=8, value=1) + with st.container(border=True): + input_width, input_height = input_image.size + canvas_result = st_canvas( + 
fill_color=fill_color, + stroke_width=stroke_width, + stroke_color=stroke_color, + background_color="rgba(255, 255, 255, 0)", + background_image=input_image, + update_streamlit=True, + height=int(512 / input_width * input_height), + width=512, + drawing_mode=drawing_mode, + key="canvas" + ) + + +with column_output: + run_button = st.button("Generate image", type="primary") + auto_update = st.checkbox("Auto update", value=False) + num_image_columns = st.slider("Columns", min_value=1, max_value=8, value=2) + image_columns = st.columns(num_image_columns) + + # Run + if (run_button or auto_update) and model_path != "None": + + if input_image is not None: + input_image = input_image.resize((width, height)) + if canvas_result.image_data is not None: + input_image = apply_stroke_to_image(canvas_result.image_data, input_image) + + output_images = [] + for image_id in range(num_images * repetition): + if use_fixed_seed: + torch.manual_seed(seed + image_id) + else: + torch.manual_seed(np.random.randint(0, 10**9)) + if image_id >= num_images: + input_image = output_images[image_id - num_images] + with image_columns[image_id % num_image_columns]: + progress_bar_st = st.progress(0.0) + image = pipeline( + prompt, negative_prompt=negative_prompt, + cfg_scale=cfg_scale, num_inference_steps=num_inference_steps, + height=height, width=width, + input_image=input_image, denoising_strength=denoising_strength, + progress_bar_st=progress_bar_st + ) + output_images.append(image) + progress_bar_st.progress(1.0) + show_output_image(image) + st.session_state["output_images"] = output_images + + elif "output_images" in st.session_state: + for image_id in range(len(st.session_state.output_images)): + with image_columns[image_id % num_image_columns]: + image = st.session_state.output_images[image_id] + progress_bar = st.progress(1.0) + show_output_image(image) diff --git a/pages/2_Video_Creator.py b/pages/2_Video_Creator.py new file mode 100644 index 0000000..08e2f5f --- /dev/null +++ 
b/pages/2_Video_Creator.py @@ -0,0 +1,4 @@ +import streamlit as st +st.set_page_config(layout="wide") + +st.markdown("# Coming soon")