fix qwen_text_encoder bug in transformers>=5.2.0

This commit is contained in:
mi804
2026-03-03 12:44:36 +08:00
parent b3ef224042
commit 62ba8a3f2e
2 changed files with 17 additions and 2 deletions

View File

@@ -250,3 +250,17 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
}
def QwenImageTextEncoder_Module_Map_Updater():
current = VRAM_MANAGEMENT_MODULE_MAPS["diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder"]
from packaging import version
import transformers
if version.parse(transformers.__version__) >= version.parse("5.2.0"):
# The Qwen2RMSNorm in transformers 5.2.0+ has been renamed to Qwen2_5_VLRMSNorm, so we need to update the module map accordingly
current.pop("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm", None)
current["transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRMSNorm"] = "diffsynth.core.vram.layers.AutoWrappedModule"
return current
VERSION_CHECKER_MAPS = {
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": QwenImageTextEncoder_Module_Map_Updater,
}