diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index f017aaa..299d9d5 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -906,6 +906,14 @@ stable_diffusion_xl_series = [ "model_hash": "142b114f67f5ab3a6d83fb5788f12ded", "model_name": "stable_diffusion_xl_unet", "model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel", + "extra_kwargs": { + "attention_head_dim": [5, 10, 20], + "transformer_layers_per_block": [1, 2, 10], + "use_linear_projection": True, + "addition_embed_type": "text_time", + "addition_time_embed_dim": 256, + "projection_class_embeddings_input_dim": 2816, + }, }, { # Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors") diff --git a/diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py b/diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py index 789decb..ec3ae28 100644 --- a/diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py +++ b/diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py @@ -1,7 +1,13 @@ +import torch + def SDXLTextEncoder2StateDictConverter(state_dict): new_state_dict = {} for key in state_dict: - if key.startswith("text_model.") and "position_ids" not in key: + if key == "text_projection.weight": + val = state_dict[key] + new_state_dict["model.text_projection.weight"] = val.float() if val.dtype == torch.float16 else val + elif key.startswith("text_model.") and "position_ids" not in key: new_key = "model." + key - new_state_dict[new_key] = state_dict[key] + val = state_dict[key] + new_state_dict[new_key] = val.float() if val.dtype == torch.float16 else val return new_state_dict