mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
sdxl modelcode
This commit is contained in:
@@ -906,6 +906,14 @@ stable_diffusion_xl_series = [
|
|||||||
"model_hash": "142b114f67f5ab3a6d83fb5788f12ded",
|
"model_hash": "142b114f67f5ab3a6d83fb5788f12ded",
|
||||||
"model_name": "stable_diffusion_xl_unet",
|
"model_name": "stable_diffusion_xl_unet",
|
||||||
"model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel",
|
"model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel",
|
||||||
|
"extra_kwargs": {
|
||||||
|
"attention_head_dim": [5, 10, 20],
|
||||||
|
"transformer_layers_per_block": [1, 2, 10],
|
||||||
|
"use_linear_projection": True,
|
||||||
|
"addition_embed_type": "text_time",
|
||||||
|
"addition_time_embed_dim": 256,
|
||||||
|
"projection_class_embeddings_input_dim": 2816,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors")
|
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors")
|
||||||
|
|||||||
@@ -1,7 +1,13 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
def SDXLTextEncoder2StateDictConverter(state_dict):
|
def SDXLTextEncoder2StateDictConverter(state_dict):
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
for key in state_dict:
|
for key in state_dict:
|
||||||
if key.startswith("text_model.") and "position_ids" not in key:
|
if key == "text_projection.weight":
|
||||||
|
val = state_dict[key]
|
||||||
|
new_state_dict["model.text_projection.weight"] = val.float() if val.dtype == torch.float16 else val
|
||||||
|
elif key.startswith("text_model.") and "position_ids" not in key:
|
||||||
new_key = "model." + key
|
new_key = "model." + key
|
||||||
new_state_dict[new_key] = state_dict[key]
|
val = state_dict[key]
|
||||||
|
new_state_dict[new_key] = val.float() if val.dtype == torch.float16 else val
|
||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|||||||
Reference in New Issue
Block a user