diff --git a/models/Kolors/Put Kolors checkpoints here.txt b/models/kolors/Put Kolors checkpoints here.txt similarity index 100% rename from models/Kolors/Put Kolors checkpoints here.txt rename to models/kolors/Put Kolors checkpoints here.txt diff --git a/pages/1_Image_Creator.py b/pages/1_Image_Creator.py index d5a26de..9fb49ca 100644 --- a/pages/1_Image_Creator.py +++ b/pages/1_Image_Creator.py @@ -36,6 +36,11 @@ config = { "width": 512, } }, + "Kolors": { + "model_folder": "models/kolors", + "pipeline_class": SDXLImagePipeline, + "fixed_parameters": {} + }, "HunyuanDiT": { "model_folder": "models/HunyuanDiT", "pipeline_class": HunyuanDiTImagePipeline, @@ -50,7 +55,7 @@ config = { def load_model_list(model_type): folder = config[model_type]["model_folder"] file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")] - if model_type == "HunyuanDiT": + if model_type in ["HunyuanDiT", "Kolors"]: file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))] file_list = sorted(file_list) return file_list @@ -74,6 +79,12 @@ def load_model(model_type, model_path): os.path.join(model_path, "model/pytorch_model_ema.pt"), os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"), ]) + elif model_type == "Kolors": + model_manager.load_models([ + os.path.join(model_path, "text_encoder"), + os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"), + os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"), + ]) else: model_manager.load_model(model_path) pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)