sdxl modelcode

This commit is contained in:
mi804
2026-04-23 18:26:25 +08:00
parent 82e482286c
commit 9453700a30
2 changed files with 16 additions and 2 deletions

View File

@@ -906,6 +906,14 @@ stable_diffusion_xl_series = [
"model_hash": "142b114f67f5ab3a6d83fb5788f12ded", "model_hash": "142b114f67f5ab3a6d83fb5788f12ded",
"model_name": "stable_diffusion_xl_unet", "model_name": "stable_diffusion_xl_unet",
"model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel", "model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel",
"extra_kwargs": {
"attention_head_dim": [5, 10, 20],
"transformer_layers_per_block": [1, 2, 10],
"use_linear_projection": True,
"addition_embed_type": "text_time",
"addition_time_embed_dim": 256,
"projection_class_embeddings_input_dim": 2816,
},
}, },
{ {
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors") # Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors")

View File

@@ -1,7 +1,13 @@
import torch
def SDXLTextEncoder2StateDictConverter(state_dict): def SDXLTextEncoder2StateDictConverter(state_dict):
new_state_dict = {} new_state_dict = {}
for key in state_dict: for key in state_dict:
if key.startswith("text_model.") and "position_ids" not in key: if key == "text_projection.weight":
val = state_dict[key]
new_state_dict["model.text_projection.weight"] = val.float() if val.dtype == torch.float16 else val
elif key.startswith("text_model.") and "position_ids" not in key:
new_key = "model." + key new_key = "model." + key
new_state_dict[new_key] = state_dict[key] val = state_dict[key]
new_state_dict[new_key] = val.float() if val.dtype == torch.float16 else val
return new_state_dict return new_state_dict