Support VRAM management in Flux

This commit is contained in:
Artiprocher
2025-02-13 15:11:39 +08:00
parent 46d4616e23
commit 0699212665
8 changed files with 246 additions and 6 deletions

View File

@@ -80,7 +80,10 @@ def load_model_from_single_file(state_dict, model_names, model_classes, model_re
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
else:
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
if torch_dtype == torch.float16 and hasattr(model, "half"):
model = model.half()
try:

View File

@@ -9,7 +9,8 @@ class SD3TextEncoder1(SDTextEncoder):
super().__init__(vocab_size=vocab_size)
def forward(self, input_ids, clip_skip=2, extra_mask=None):
embeds = self.token_embedding(input_ids) + self.position_embeds
embeds = self.token_embedding(input_ids)
embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
if extra_mask is not None:
attn_mask[:, extra_mask[0]==0] = float("-inf")