mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 06:46:13 +00:00
sdxl modelcode
This commit is contained in:
@@ -906,6 +906,14 @@ stable_diffusion_xl_series = [
|
||||
"model_hash": "142b114f67f5ab3a6d83fb5788f12ded",
|
||||
"model_name": "stable_diffusion_xl_unet",
|
||||
"model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel",
|
||||
"extra_kwargs": {
|
||||
"attention_head_dim": [5, 10, 20],
|
||||
"transformer_layers_per_block": [1, 2, 10],
|
||||
"use_linear_projection": True,
|
||||
"addition_embed_type": "text_time",
|
||||
"addition_time_embed_dim": 256,
|
||||
"projection_class_embeddings_input_dim": 2816,
|
||||
},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors")
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
import torch
|
||||
|
||||
def SDXLTextEncoder2StateDictConverter(state_dict):
|
||||
new_state_dict = {}
|
||||
for key in state_dict:
|
||||
if key.startswith("text_model.") and "position_ids" not in key:
|
||||
if key == "text_projection.weight":
|
||||
val = state_dict[key]
|
||||
new_state_dict["model.text_projection.weight"] = val.float() if val.dtype == torch.float16 else val
|
||||
elif key.startswith("text_model.") and "position_ids" not in key:
|
||||
new_key = "model." + key
|
||||
new_state_dict[new_key] = state_dict[key]
|
||||
val = state_dict[key]
|
||||
new_state_dict[new_key] = val.float() if val.dtype == torch.float16 else val
|
||||
return new_state_dict
|
||||
|
||||
Reference in New Issue
Block a user