Files
DiffSynth-Studio/Diffsynth_Studio.py
2023-12-08 01:03:30 +08:00

265 lines
11 KiB
Python

import torch, os, io
import numpy as np
from PIL import Image
import streamlit as st
st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas
from diffsynth.models import ModelManager
from diffsynth.prompts import SDXLPrompter, SDPrompter
from diffsynth.pipelines import SDXLPipeline, SDPipeline
torch.cuda.set_per_process_memory_fraction(0.999, 0)
@st.cache_data
def load_model_list(folder):
file_list = os.listdir(folder)
file_list = [i for i in file_list if i.endswith(".safetensors")]
file_list = sorted(file_list)
return file_list
def detect_model_path(sd_model_path, sdxl_model_path):
if sd_model_path != "None":
model_path = os.path.join("models/stable_diffusion", sd_model_path)
elif sdxl_model_path != "None":
model_path = os.path.join("models/stable_diffusion_xl", sdxl_model_path)
else:
model_path = None
return model_path
def load_model(sd_model_path, sdxl_model_path):
if sd_model_path != "None":
model_path = os.path.join("models/stable_diffusion", sd_model_path)
model_manager = ModelManager()
model_manager.load_from_safetensors(model_path)
prompter = SDPrompter()
pipeline = SDPipeline()
elif sdxl_model_path != "None":
model_path = os.path.join("models/stable_diffusion_xl", sdxl_model_path)
model_manager = ModelManager()
model_manager.load_from_safetensors(model_path)
prompter = SDXLPrompter()
pipeline = SDXLPipeline()
else:
return None, None, None, None
return model_path, model_manager, prompter, pipeline
def release_model():
if "model_manager" in st.session_state:
st.session_state["model_manager"].to("cpu")
del st.session_state["loaded_model_path"]
del st.session_state["model_manager"]
del st.session_state["prompter"]
del st.session_state["pipeline"]
torch.cuda.empty_cache()
def use_output_image_as_input():
# Search for input image
output_image_id = 0
selected_output_image = None
while True:
if f"use_output_as_input_{output_image_id}" not in st.session_state:
break
if st.session_state[f"use_output_as_input_{output_image_id}"]:
selected_output_image = st.session_state["output_images"][output_image_id]
break
output_image_id += 1
if selected_output_image is not None:
st.session_state["input_image"] = selected_output_image
def apply_stroke_to_image(stroke_image, image):
image = np.array(image.convert("RGB")).astype(np.float32)
height, width, _ = image.shape
stroke_image = np.array(Image.fromarray(stroke_image).resize((width, height))).astype(np.float32)
weight = stroke_image[:, :, -1:] / 255
stroke_image = stroke_image[:, :, :-1]
image = stroke_image * weight + image * (1 - weight)
image = np.clip(image, 0, 255).astype(np.uint8)
image = Image.fromarray(image)
return image
@st.cache_data
def image2bits(image):
image_byte = io.BytesIO()
image.save(image_byte, format="PNG")
image_byte = image_byte.getvalue()
return image_byte
def show_output_image(image):
st.image(image, use_column_width="always")
st.button("Use it as input image", key=f"use_output_as_input_{image_id}")
st.download_button("Download", data=image2bits(image), file_name="image.png", mime="image/png", key=f"download_output_{image_id}")
column_input, column_output = st.columns(2)
# with column_input:
with st.sidebar:
# Select a model
with st.expander("Model", expanded=True):
sd_model_list = ["None"] + load_model_list("models/stable_diffusion")
sd_model_path = st.selectbox(
"Stable Diffusion", sd_model_list
)
sdxl_model_list = ["None"] + load_model_list("models/stable_diffusion_xl")
sdxl_model_path = st.selectbox(
"Stable Diffusion XL", sdxl_model_list
)
# Load the model
model_path = detect_model_path(sd_model_path, sdxl_model_path)
if model_path is None:
st.markdown("No models selected.")
release_model()
elif st.session_state.get("loaded_model_path", "") != model_path:
st.markdown(f"Using model at {model_path}.")
release_model()
model_path, model_manager, prompter, pipeline = load_model(sd_model_path, sdxl_model_path)
st.session_state.loaded_model_path = model_path
st.session_state.model_manager = model_manager
st.session_state.prompter = prompter
st.session_state.pipeline = pipeline
else:
st.markdown(f"Using model at {model_path}.")
model_path, model_manager, prompter, pipeline = (
st.session_state.loaded_model_path,
st.session_state.model_manager,
st.session_state.prompter,
st.session_state.pipeline,
)
# Show parameters
with st.expander("Prompt", expanded=True):
column_positive, column_negative = st.columns(2)
prompt = st.text_area("Positive prompt")
negative_prompt = st.text_area("Negative prompt")
with st.expander("Classifier-free guidance", expanded=True):
use_cfg = st.checkbox("Use classifier-free guidance", value=True)
if use_cfg:
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, step=0.1, value=7.5)
else:
cfg_scale = 1.0
with st.expander("Inference steps", expanded=True):
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20, label_visibility="hidden")
with st.expander("Image size", expanded=True):
height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512)
width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512)
with st.expander("Seed", expanded=True):
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
if use_fixed_seed:
seed = st.number_input("Random seed", value=0, label_visibility="hidden")
with st.expander("Number of images", expanded=True):
num_images = st.number_input("Number of images", value=4, label_visibility="hidden")
with st.expander("Tile (for high resolution)", expanded=True):
tiled = st.checkbox("Use tile", value=False)
tile_size = st.select_slider("Tile size", options=[64, 128], value=64)
tile_stride = st.select_slider("Tile stride", options=[8, 16, 32, 64], value=32)
# Show input image
with column_input:
with st.expander("Input image (Optional)", expanded=True):
with st.container(border=True):
column_white_board, column_upload_image = st.columns([1, 2])
with column_white_board:
create_white_board = st.button("Create white board")
delete_input_image = st.button("Delete input image")
with column_upload_image:
upload_image = st.file_uploader("Upload image", type=["png", "jpg"], key="upload_image")
if upload_image is not None:
st.session_state["input_image"] = Image.open(upload_image)
elif create_white_board:
st.session_state["input_image"] = Image.fromarray(np.ones((1024, 1024, 3), dtype=np.uint8) * 255)
else:
use_output_image_as_input()
if delete_input_image and "input_image" in st.session_state:
del st.session_state.input_image
if delete_input_image and "upload_image" in st.session_state:
del st.session_state.upload_image
input_image = st.session_state.get("input_image", None)
if input_image is not None:
with st.container(border=True):
column_drawing_mode, column_color_1, column_color_2 = st.columns([4, 1, 1])
with column_drawing_mode:
drawing_mode = st.radio("Drawing tool", ["transform", "freedraw", "line", "rect"], horizontal=True, index=1)
with column_color_1:
stroke_color = st.color_picker("Stroke color")
with column_color_2:
fill_color = st.color_picker("Fill color")
stroke_width = st.slider("Stroke width", min_value=1, max_value=50, value=10)
with st.container(border=True):
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=0.7)
with st.container(border=True):
input_width, input_height = input_image.size
canvas_result = st_canvas(
fill_color=fill_color,
stroke_width=stroke_width,
stroke_color=stroke_color,
background_color="rgba(255, 255, 255, 0)",
background_image=input_image,
update_streamlit=True,
height=int(512 / input_width * input_height),
width=512,
drawing_mode=drawing_mode,
key="canvas"
)
with column_output:
run_button = st.button("Generate image", type="primary")
auto_update = st.checkbox("Auto update", value=False)
num_image_columns = st.slider("Columns", min_value=1, max_value=8, value=2)
image_columns = st.columns(num_image_columns)
# Run
if (run_button or auto_update) and model_path is not None:
if not use_fixed_seed:
torch.manual_seed(np.random.randint(0, 10**9))
output_images = []
for image_id in range(num_images):
if use_fixed_seed:
torch.manual_seed(seed + image_id)
if input_image is not None:
input_image = input_image.resize((width, height))
if canvas_result.image_data is not None:
input_image = apply_stroke_to_image(canvas_result.image_data, input_image)
else:
denoising_strength = 1.0
with image_columns[image_id % num_image_columns]:
progress_bar = st.progress(0.0)
image = pipeline(
model_manager, prompter,
prompt, negative_prompt=negative_prompt, cfg_scale=cfg_scale,
num_inference_steps=num_inference_steps,
height=height, width=width,
init_image=input_image, denoising_strength=denoising_strength,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
progress_bar_st=progress_bar
)
output_images.append(image)
progress_bar.progress(1.0)
show_output_image(image)
st.session_state["output_images"] = output_images
elif "output_images" in st.session_state:
for image_id in range(len(st.session_state.output_images)):
with image_columns[image_id % num_image_columns]:
image = st.session_state.output_images[image_id]
progress_bar = st.progress(1.0)
show_output_image(image)