From f6de5eef4d82038aa2954fda7ea0a35865885e82 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 6 Jun 2024 15:56:24 +0800 Subject: [PATCH] Hunyuan DiT UI --- environment.yml | 1 + pages/1_Image_Creator.py | 37 +++++++++++++++++++++++++++++-------- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/environment.yml b/environment.yml index eb0edf4..ba447dc 100644 --- a/environment.yml +++ b/environment.yml @@ -18,3 +18,4 @@ dependencies: - imageio[ffmpeg] - safetensors - einops + - sentencepiece diff --git a/pages/1_Image_Creator.py b/pages/1_Image_Creator.py index 8c735fa..2b50072 100644 --- a/pages/1_Image_Creator.py +++ b/pages/1_Image_Creator.py @@ -5,7 +5,7 @@ import streamlit as st st.set_page_config(layout="wide") from streamlit_drawable_canvas import st_canvas from diffsynth.models import ModelManager -from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline +from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, HunyuanDiTImagePipeline from diffsynth.data.video import crop_and_resize @@ -30,14 +30,23 @@ config = { "height": 512, "width": 512, } - } + }, + "HunyuanDiT": { + "model_folder": "models/HunyuanDiT", + "pipeline_class": HunyuanDiTImagePipeline, + "fixed_parameters": { + "height": 1024, + "width": 1024, + } + }, } def load_model_list(model_type): folder = config[model_type]["model_folder"] - file_list = os.listdir(folder) - file_list = [i for i in file_list if i.endswith(".safetensors")] + file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")] + if model_type == "HunyuanDiT": + file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))] file_list = sorted(file_list) return file_list @@ -53,7 +62,15 @@ def release_model(): def load_model(model_type, model_path): model_manager = ModelManager() - model_manager.load_model(model_path) + if model_type == "HunyuanDiT": + model_manager.load_models([ + os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"), + os.path.join(model_path, "mt5/pytorch_model.bin"), + os.path.join(model_path, "model/pytorch_model_ema.pt"), + os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"), + ]) + else: + model_manager.load_model(model_path) pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager) st.session_state.loaded_model_path = model_path st.session_state.model_manager = model_manager @@ -109,7 +126,7 @@ column_input, column_output = st.columns(2) with st.sidebar: # Select a model with st.expander("Model", expanded=True): - model_type = st.selectbox("Model type", ["Stable Diffusion", "Stable Diffusion XL", "Stable Diffusion XL Turbo"]) + model_type = st.selectbox("Model type", [model_type_ for model_type_ in config]) fixed_parameters = config[model_type]["fixed_parameters"] model_path_list = ["None"] + load_model_list(model_type) model_path = st.selectbox("Model path", model_path_list) @@ -124,13 +141,17 @@ with st.sidebar: model_path = os.path.join(config[model_type]["model_folder"], model_path) if st.session_state.get("loaded_model_path", "") != model_path: # The loaded model is not the selected model. Reload it. - st.markdown(f"Using model at {model_path}.") + st.markdown(f"Loading model at {model_path}.") + st.markdown("Please wait a moment...") release_model() model_manager, pipeline = load_model(model_type, model_path) + st.markdown("Done.") else: # The loaded model is not the selected model. Fetch it from `st.session_state`. - st.markdown(f"Using model at {model_path}.") + st.markdown(f"Loading model at {model_path}.") + st.markdown("Please wait a moment...") model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline + st.markdown("Done.") # Show parameters with st.expander("Prompt", expanded=True):