From a6aaf9da2a7eee0d8cef908bc6f0ff8f8b22b142 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 19 Aug 2024 14:24:23 +0800 Subject: [PATCH] support flux UI --- .../Put Stable Diffusion checkpoints here.txt | 0 pages/1_Image_Creator.py | 21 +++++++++++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) create mode 100644 models/FLUX/Put Stable Diffusion checkpoints here.txt diff --git a/models/FLUX/Put Stable Diffusion checkpoints here.txt b/models/FLUX/Put Stable Diffusion checkpoints here.txt new file mode 100644 index 0000000..e69de29 diff --git a/pages/1_Image_Creator.py b/pages/1_Image_Creator.py index 2d13782..3b8ad45 100644 --- a/pages/1_Image_Creator.py +++ b/pages/1_Image_Creator.py @@ -5,7 +5,7 @@ import streamlit as st st.set_page_config(layout="wide") from streamlit_drawable_canvas import st_canvas from diffsynth.models import ModelManager -from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline +from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline from diffsynth.data.video import crop_and_resize @@ -49,13 +49,20 @@ config = { "width": 1024, } }, + "FLUX": { + "model_folder": "models/FLUX", + "pipeline_class": FluxImagePipeline, + "fixed_parameters": { + "cfg_scale": 1.0, + } + } } def load_model_list(model_type): folder = config[model_type]["model_folder"] file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")] - if model_type in ["HunyuanDiT", "Kolors"]: + if model_type in ["HunyuanDiT", "Kolors", "FLUX"]: file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))] file_list = sorted(file_list) return file_list @@ -85,6 +92,16 @@ def load_model(model_type, model_path): os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"), os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"), ]) + elif model_type == "FLUX": + model_manager.torch_dtype = torch.bfloat16 + file_list = [ + os.path.join(model_path, "text_encoder/model.safetensors"), + os.path.join(model_path, "text_encoder_2"), + ] + for file_name in os.listdir(model_path): + if file_name.endswith(".safetensors"): + file_list.append(os.path.join(model_path, file_name)) + model_manager.load_models(file_list) else: model_manager.load_model(model_path) pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)