sdxl modelcode

This commit is contained in:
mi804
2026-04-23 18:26:25 +08:00
parent 82e482286c
commit 9453700a30
2 changed files with 16 additions and 2 deletions

View File

@@ -1,7 +1,13 @@
import torch
def SDXLTextEncoder2StateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:
if key.startswith("text_model.") and "position_ids" not in key:
if key == "text_projection.weight":
val = state_dict[key]
new_state_dict["model.text_projection.weight"] = val.float() if val.dtype == torch.float16 else val
elif key.startswith("text_model.") and "position_ids" not in key:
new_key = "model." + key
new_state_dict[new_key] = state_dict[key]
val = state_dict[key]
new_state_dict[new_key] = val.float() if val.dtype == torch.float16 else val
return new_state_dict