Update VRAM management strategy (#929)

This commit is contained in:
Zhongjie Duan
2025-09-18 16:53:13 +08:00
committed by GitHub

View File

@@ -194,11 +194,12 @@ class QwenImagePipeline(BasePipeline):
enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, auto_offload=True, enable_dit_fp8_computation=False):
self.vram_management_enabled = True
if vram_limit is None:
if vram_limit is None and auto_offload:
vram_limit = self.get_vram()
vram_limit = vram_limit - vram_buffer
if vram_limit is not None:
vram_limit = vram_limit - vram_buffer
if self.text_encoder is not None:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding