Merge pull request #1321 from mi804/bugfix

fix qwen_text_encoder bug in transformers>=5.2.0
This commit is contained in:
Zhongjie Duan
2026-03-03 13:02:45 +08:00
committed by GitHub
2 changed files with 17 additions and 2 deletions

View File

@@ -250,3 +250,17 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
}
def QwenImageTextEncoder_Module_Map_Updater():
current = VRAM_MANAGEMENT_MODULE_MAPS["diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder"]
from packaging import version
import transformers
if version.parse(transformers.__version__) >= version.parse("5.2.0"):
# The Qwen2RMSNorm in transformers 5.2.0+ has been renamed to Qwen2_5_VLRMSNorm, so we need to update the module map accordingly
current.pop("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm", None)
current["transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRMSNorm"] = "diffsynth.core.vram.layers.AutoWrappedModule"
return current
VERSION_CHECKER_MAPS = {
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": QwenImageTextEncoder_Module_Map_Updater,
}

View File

@@ -1,6 +1,6 @@
from ..core.loader import load_model, hash_model_file
from ..core.vram import AutoWrappedModule
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS
import importlib, json, torch
@@ -22,7 +22,8 @@ class ModelPool:
def fetch_module_map(self, model_class, vram_config):
if self.need_to_enable_vram_management(vram_config):
if model_class in VRAM_MANAGEMENT_MODULE_MAPS:
module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in VRAM_MANAGEMENT_MODULE_MAPS[model_class].items()}
vram_module_map = VRAM_MANAGEMENT_MODULE_MAPS[model_class] if model_class not in VERSION_CHECKER_MAPS else VERSION_CHECKER_MAPS[model_class]()
module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in vram_module_map.items()}
else:
module_map = {self.import_model_class(model_class): AutoWrappedModule}
else: