support flux UI

This commit is contained in:
Artiprocher
2024-08-19 14:24:23 +08:00
parent aa908ae0c2
commit a6aaf9da2a
2 changed files with 19 additions and 2 deletions

View File

@@ -5,7 +5,7 @@ import streamlit as st
st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas
from diffsynth.models import ModelManager
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
from diffsynth.data.video import crop_and_resize
@@ -49,13 +49,20 @@ config = {
"width": 1024,
}
},
"FLUX": {
"model_folder": "models/FLUX",
"pipeline_class": FluxImagePipeline,
"fixed_parameters": {
"cfg_scale": 1.0,
}
}
}
def load_model_list(model_type):
folder = config[model_type]["model_folder"]
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
if model_type in ["HunyuanDiT", "Kolors"]:
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
file_list = sorted(file_list)
return file_list
@@ -85,6 +92,16 @@ def load_model(model_type, model_path):
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
])
elif model_type == "FLUX":
model_manager.torch_dtype = torch.bfloat16
file_list = [
os.path.join(model_path, "text_encoder/model.safetensors"),
os.path.join(model_path, "text_encoder_2"),
]
for file_name in os.listdir(model_path):
if file_name.endswith(".safetensors"):
file_list.append(os.path.join(model_path, file_name))
model_manager.load_models(file_list)
else:
model_manager.load_model(model_path)
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)