diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index d9da07c..d458aa6 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -103,7 +103,7 @@ class QwenImagePipeline(BasePipeline): vram_limit = vram_limit - vram_buffer if self.text_encoder is not None: - from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding dtype = next(iter(self.text_encoder.parameters())).dtype enable_vram_management( self.text_encoder, @@ -112,6 +112,8 @@ class QwenImagePipeline(BasePipeline): torch.nn.Embedding: AutoWrappedModule, Qwen2_5_VLRotaryEmbedding: AutoWrappedModule, Qwen2RMSNorm: AutoWrappedModule, + Qwen2_5_VisionPatchEmbed: AutoWrappedModule, + Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule, }, module_config = dict( offload_dtype=dtype,