add low vram examples

This commit is contained in:
Artiprocher
2025-08-15 11:31:57 +08:00
parent 0b574cc0c2
commit e1c2eda5f5
12 changed files with 269 additions and 36 deletions

View File

@@ -193,6 +193,23 @@ class QwenImagePipeline(BasePipeline):
),
vram_limit=vram_limit,
)
if self.blockwise_controlnet is not None:
enable_vram_management(
self.blockwise_controlnet,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
@staticmethod
@@ -393,7 +410,7 @@ class QwenImageUnit_EntityControl(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
onload_model_names=("text_encoder")
onload_model_names=("text_encoder",)
)
def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):