mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
Hunyuan DiT UI
This commit is contained in:
@@ -18,3 +18,4 @@ dependencies:
|
|||||||
- imageio[ffmpeg]
|
- imageio[ffmpeg]
|
||||||
- safetensors
|
- safetensors
|
||||||
- einops
|
- einops
|
||||||
|
- sentencepiece
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import streamlit as st
|
|||||||
st.set_page_config(layout="wide")
|
st.set_page_config(layout="wide")
|
||||||
from streamlit_drawable_canvas import st_canvas
|
from streamlit_drawable_canvas import st_canvas
|
||||||
from diffsynth.models import ModelManager
|
from diffsynth.models import ModelManager
|
||||||
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline
|
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, HunyuanDiTImagePipeline
|
||||||
from diffsynth.data.video import crop_and_resize
|
from diffsynth.data.video import crop_and_resize
|
||||||
|
|
||||||
|
|
||||||
@@ -30,14 +30,23 @@ config = {
|
|||||||
"height": 512,
|
"height": 512,
|
||||||
"width": 512,
|
"width": 512,
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
"HunyuanDiT": {
|
||||||
|
"model_folder": "models/HunyuanDiT",
|
||||||
|
"pipeline_class": HunyuanDiTImagePipeline,
|
||||||
|
"fixed_parameters": {
|
||||||
|
"height": 1024,
|
||||||
|
"width": 1024,
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_model_list(model_type):
|
def load_model_list(model_type):
|
||||||
folder = config[model_type]["model_folder"]
|
folder = config[model_type]["model_folder"]
|
||||||
file_list = os.listdir(folder)
|
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
|
||||||
file_list = [i for i in file_list if i.endswith(".safetensors")]
|
if model_type == "HunyuanDiT":
|
||||||
|
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
|
||||||
file_list = sorted(file_list)
|
file_list = sorted(file_list)
|
||||||
return file_list
|
return file_list
|
||||||
|
|
||||||
@@ -53,7 +62,15 @@ def release_model():
|
|||||||
|
|
||||||
def load_model(model_type, model_path):
|
def load_model(model_type, model_path):
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.load_model(model_path)
|
if model_type == "HunyuanDiT":
|
||||||
|
model_manager.load_models([
|
||||||
|
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
|
||||||
|
os.path.join(model_path, "mt5/pytorch_model.bin"),
|
||||||
|
os.path.join(model_path, "model/pytorch_model_ema.pt"),
|
||||||
|
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
model_manager.load_model(model_path)
|
||||||
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
|
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
|
||||||
st.session_state.loaded_model_path = model_path
|
st.session_state.loaded_model_path = model_path
|
||||||
st.session_state.model_manager = model_manager
|
st.session_state.model_manager = model_manager
|
||||||
@@ -109,7 +126,7 @@ column_input, column_output = st.columns(2)
|
|||||||
with st.sidebar:
|
with st.sidebar:
|
||||||
# Select a model
|
# Select a model
|
||||||
with st.expander("Model", expanded=True):
|
with st.expander("Model", expanded=True):
|
||||||
model_type = st.selectbox("Model type", ["Stable Diffusion", "Stable Diffusion XL", "Stable Diffusion XL Turbo"])
|
model_type = st.selectbox("Model type", [model_type_ for model_type_ in config])
|
||||||
fixed_parameters = config[model_type]["fixed_parameters"]
|
fixed_parameters = config[model_type]["fixed_parameters"]
|
||||||
model_path_list = ["None"] + load_model_list(model_type)
|
model_path_list = ["None"] + load_model_list(model_type)
|
||||||
model_path = st.selectbox("Model path", model_path_list)
|
model_path = st.selectbox("Model path", model_path_list)
|
||||||
@@ -124,13 +141,17 @@ with st.sidebar:
|
|||||||
model_path = os.path.join(config[model_type]["model_folder"], model_path)
|
model_path = os.path.join(config[model_type]["model_folder"], model_path)
|
||||||
if st.session_state.get("loaded_model_path", "") != model_path:
|
if st.session_state.get("loaded_model_path", "") != model_path:
|
||||||
# The loaded model is not the selected model. Reload it.
|
# The loaded model is not the selected model. Reload it.
|
||||||
st.markdown(f"Using model at {model_path}.")
|
st.markdown(f"Loading model at {model_path}.")
|
||||||
|
st.markdown("Please wait a moment...")
|
||||||
release_model()
|
release_model()
|
||||||
model_manager, pipeline = load_model(model_type, model_path)
|
model_manager, pipeline = load_model(model_type, model_path)
|
||||||
|
st.markdown("Done.")
|
||||||
else:
|
else:
|
||||||
# The loaded model is not the selected model. Fetch it from `st.session_state`.
|
# The loaded model is not the selected model. Fetch it from `st.session_state`.
|
||||||
st.markdown(f"Using model at {model_path}.")
|
st.markdown(f"Loading model at {model_path}.")
|
||||||
|
st.markdown("Please wait a moment...")
|
||||||
model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline
|
model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline
|
||||||
|
st.markdown("Done.")
|
||||||
|
|
||||||
# Show parameters
|
# Show parameters
|
||||||
with st.expander("Prompt", expanded=True):
|
with st.expander("Prompt", expanded=True):
|
||||||
|
|||||||
Reference in New Issue
Block a user