From 9453700a300cbfc8606f33372b9e15feb1f23fd9 Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Thu, 23 Apr 2026 18:26:25 +0800
Subject: [PATCH] sdxl modelcode

---
Notes:
    Reviewer notes. This section sits after the three-dash separator, so it
    is not part of the commit message.

    diffsynth/configs/model_configs.py
      * The "stable_diffusion_xl_unet" entry of stable_diffusion_xl_series
        (model_class SDXLUNet2DConditionModel) gains an "extra_kwargs" dict
        with six settings: attention_head_dim, transformer_layers_per_block,
        use_linear_projection, addition_embed_type, addition_time_embed_dim
        and projection_class_embeddings_input_dim.
      * How "extra_kwargs" is consumed is not visible in this patch;
        presumably it is forwarded to the model constructor. TODO confirm
        against the model loader.

    diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py
      * Adds a module-level "import torch", needed for the dtype comparison.
      * "text_projection.weight" is now emitted as
        "model.text_projection.weight". Before this change only keys that
        start with "text_model." were copied, so that key was dropped.
      * Keys starting with "text_model." are still renamed by prefixing
        "model.", and keys containing "position_ids" are still skipped.
      * Every copied tensor whose dtype is torch.float16 is converted with
        .float(); tensors of any other dtype are stored unchanged.
      * NOTE(review): the expression
        "val.float() if val.dtype == torch.float16 else val" appears in both
        branches; a follow-up could share it.
      * NOTE(review): only torch.float16 is upcast; torch.bfloat16 passes
        through as-is. Confirm that this is intended.

    NOTE(review): leading whitespace inside the hunks below was
    reconstructed from a whitespace-collapsed copy of this patch. Verify it
    against the target files before applying.

 diffsynth/configs/model_configs.py                     |  8 ++++++++
 .../stable_diffusion_xl_text_encoder.py                | 10 ++++++++--
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py
index f017aaa..299d9d5 100644
--- a/diffsynth/configs/model_configs.py
+++ b/diffsynth/configs/model_configs.py
@@ -906,6 +906,14 @@ stable_diffusion_xl_series = [
         "model_hash": "142b114f67f5ab3a6d83fb5788f12ded",
         "model_name": "stable_diffusion_xl_unet",
         "model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel",
+        "extra_kwargs": {
+            "attention_head_dim": [5, 10, 20],
+            "transformer_layers_per_block": [1, 2, 10],
+            "use_linear_projection": True,
+            "addition_embed_type": "text_time",
+            "addition_time_embed_dim": 256,
+            "projection_class_embeddings_input_dim": 2816,
+        },
     },
     {
         # Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors")
diff --git a/diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py b/diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py
index 789decb..ec3ae28 100644
--- a/diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py
+++ b/diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py
@@ -1,7 +1,13 @@
+import torch
+
 def SDXLTextEncoder2StateDictConverter(state_dict):
     new_state_dict = {}
     for key in state_dict:
-        if key.startswith("text_model.") and "position_ids" not in key:
+        if key == "text_projection.weight":
+            val = state_dict[key]
+            new_state_dict["model.text_projection.weight"] = val.float() if val.dtype == torch.float16 else val
+        elif key.startswith("text_model.") and "position_ids" not in key:
             new_key = "model." + key
-            new_state_dict[new_key] = state_dict[key]
+            val = state_dict[key]
+            new_state_dict[new_key] = val.float() if val.dtype == torch.float16 else val
     return new_state_dict