Hunyuan DiT UI

This commit is contained in:
Artiprocher
2024-06-06 15:56:24 +08:00
parent 53735151fa
commit f6de5eef4d
2 changed files with 30 additions and 8 deletions

View File

@@ -18,3 +18,4 @@ dependencies:
- imageio[ffmpeg]
- safetensors
- einops
- sentencepiece

View File

@@ -5,7 +5,7 @@ import streamlit as st
st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas
from diffsynth.models import ModelManager
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, HunyuanDiTImagePipeline
from diffsynth.data.video import crop_and_resize
@@ -30,14 +30,23 @@ config = {
"height": 512,
"width": 512,
}
}
},
"HunyuanDiT": {
"model_folder": "models/HunyuanDiT",
"pipeline_class": HunyuanDiTImagePipeline,
"fixed_parameters": {
"height": 1024,
"width": 1024,
}
},
}
def load_model_list(model_type):
folder = config[model_type]["model_folder"]
file_list = os.listdir(folder)
file_list = [i for i in file_list if i.endswith(".safetensors")]
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
if model_type == "HunyuanDiT":
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
file_list = sorted(file_list)
return file_list
@@ -53,7 +62,15 @@ def release_model():
def load_model(model_type, model_path):
model_manager = ModelManager()
model_manager.load_model(model_path)
if model_type == "HunyuanDiT":
model_manager.load_models([
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
os.path.join(model_path, "mt5/pytorch_model.bin"),
os.path.join(model_path, "model/pytorch_model_ema.pt"),
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
])
else:
model_manager.load_model(model_path)
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
st.session_state.loaded_model_path = model_path
st.session_state.model_manager = model_manager
@@ -109,7 +126,7 @@ column_input, column_output = st.columns(2)
with st.sidebar:
# Select a model
with st.expander("Model", expanded=True):
model_type = st.selectbox("Model type", ["Stable Diffusion", "Stable Diffusion XL", "Stable Diffusion XL Turbo"])
model_type = st.selectbox("Model type", [model_type_ for model_type_ in config])
fixed_parameters = config[model_type]["fixed_parameters"]
model_path_list = ["None"] + load_model_list(model_type)
model_path = st.selectbox("Model path", model_path_list)
@@ -124,13 +141,17 @@ with st.sidebar:
model_path = os.path.join(config[model_type]["model_folder"], model_path)
if st.session_state.get("loaded_model_path", "") != model_path:
# The loaded model is not the selected model. Reload it.
st.markdown(f"Using model at {model_path}.")
st.markdown(f"Loading model at {model_path}.")
st.markdown("Please wait a moment...")
release_model()
model_manager, pipeline = load_model(model_type, model_path)
st.markdown("Done.")
else:
# The loaded model is not the selected model. Fetch it from `st.session_state`.
st.markdown(f"Using model at {model_path}.")
st.markdown(f"Loading model at {model_path}.")
st.markdown("Please wait a moment...")
model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline
st.markdown("Done.")
# Show parameters
with st.expander("Prompt", expanded=True):