support kolors in webui

This commit is contained in:
Artiprocher
2024-07-29 16:24:13 +08:00
parent 05c97bc755
commit 8680f92b60
2 changed files with 12 additions and 1 deletions

View File

@@ -36,6 +36,11 @@ config = {
"width": 512, "width": 512,
} }
}, },
"Kolors": {
"model_folder": "models/kolors",
"pipeline_class": SDXLImagePipeline,
"fixed_parameters": {}
},
"HunyuanDiT": { "HunyuanDiT": {
"model_folder": "models/HunyuanDiT", "model_folder": "models/HunyuanDiT",
"pipeline_class": HunyuanDiTImagePipeline, "pipeline_class": HunyuanDiTImagePipeline,
@@ -50,7 +55,7 @@ config = {
def load_model_list(model_type): def load_model_list(model_type):
folder = config[model_type]["model_folder"] folder = config[model_type]["model_folder"]
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")] file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
if model_type == "HunyuanDiT": if model_type in ["HunyuanDiT", "Kolors"]:
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))] file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
file_list = sorted(file_list) file_list = sorted(file_list)
return file_list return file_list
@@ -74,6 +79,12 @@ def load_model(model_type, model_path):
os.path.join(model_path, "model/pytorch_model_ema.pt"), os.path.join(model_path, "model/pytorch_model_ema.pt"),
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"), os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
]) ])
elif model_type == "Kolors":
model_manager.load_models([
os.path.join(model_path, "text_encoder"),
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
])
else: else:
model_manager.load_model(model_path) model_manager.load_model(model_path)
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager) pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)