mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 15:48:20 +00:00
support vram management in flux
This commit is contained in:
@@ -9,7 +9,8 @@ class SD3TextEncoder1(SDTextEncoder):
|
||||
super().__init__(vocab_size=vocab_size)
|
||||
|
||||
def forward(self, input_ids, clip_skip=2, extra_mask=None):
|
||||
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||
embeds = self.token_embedding(input_ids)
|
||||
embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
|
||||
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||
if extra_mask is not None:
|
||||
attn_mask[:, extra_mask[0]==0] = float("-inf")
|
||||
|
||||
Reference in New Issue
Block a user