support vram management in flux

This commit is contained in:
Artiprocher
2025-02-13 15:11:39 +08:00
parent 46d4616e23
commit 0699212665
8 changed files with 246 additions and 6 deletions

View File

@@ -9,7 +9,8 @@ class SD3TextEncoder1(SDTextEncoder):
super().__init__(vocab_size=vocab_size)
def forward(self, input_ids, clip_skip=2, extra_mask=None):
embeds = self.token_embedding(input_ids) + self.position_embeds
embeds = self.token_embedding(input_ids)
embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
if extra_mask is not None:
attn_mask[:, extra_mask[0]==0] = float("-inf")