From 62ba8a3f2e129ae4d31c36e766a10d5fd13ca092 Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Tue, 3 Mar 2026 12:44:36 +0800
Subject: [PATCH] fix qwen_text_encoder bug in transformers>=5.2.0

---
 diffsynth/configs/vram_management_module_maps.py | 14 ++++++++++++++
 diffsynth/models/model_loader.py                 |  5 +++--
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py
index d86f5fa..902c38b 100644
--- a/diffsynth/configs/vram_management_module_maps.py
+++ b/diffsynth/configs/vram_management_module_maps.py
@@ -250,3 +250,17 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
         "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
     },
 }
+
+def QwenImageTextEncoder_Module_Map_Updater():
+    current = VRAM_MANAGEMENT_MODULE_MAPS["diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder"]
+    from packaging import version
+    import transformers
+    if version.parse(transformers.__version__) >= version.parse("5.2.0"):
+        # The Qwen2RMSNorm in transformers 5.2.0+ has been renamed to Qwen2_5_VLRMSNorm, so we need to update the module map accordingly
+        current.pop("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm", None)
+        current["transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRMSNorm"] = "diffsynth.core.vram.layers.AutoWrappedModule"
+    return current
+
+VERSION_CHECKER_MAPS = {
+    "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": QwenImageTextEncoder_Module_Map_Updater,
+}
\ No newline at end of file
diff --git a/diffsynth/models/model_loader.py b/diffsynth/models/model_loader.py
index 6a58c89..7a716e2 100644
--- a/diffsynth/models/model_loader.py
+++ b/diffsynth/models/model_loader.py
@@ -1,6 +1,6 @@
 from ..core.loader import load_model, hash_model_file
 from ..core.vram import AutoWrappedModule
-from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS
+from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS
 import importlib, json, torch
 
 
@@ -22,7 +22,8 @@ class ModelPool:
     def fetch_module_map(self, model_class, vram_config):
         if self.need_to_enable_vram_management(vram_config):
             if model_class in VRAM_MANAGEMENT_MODULE_MAPS:
-                module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in VRAM_MANAGEMENT_MODULE_MAPS[model_class].items()}
+                vram_module_map = VRAM_MANAGEMENT_MODULE_MAPS[model_class] if model_class not in VERSION_CHECKER_MAPS else VERSION_CHECKER_MAPS[model_class]()
+                module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in vram_module_map.items()}
             else:
                 module_map = {self.import_model_class(model_class): AutoWrappedModule}
         else: