update flux pipeline

This commit is contained in:
Artiprocher
2024-10-10 17:05:04 +08:00
parent 41ea2f811a
commit fa0fa95bb6
6 changed files with 32 additions and 22 deletions

View File

@@ -65,9 +65,11 @@ class BasePipeline(torch.nn.Module):
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
return prompt, local_prompts, masks, mask_scales
def enable_cpu_offload(self):
self.cpu_offload = True
def load_models_to_device(self, loadmodel_names=[]):
# only load models to device if cpu_offload is enabled
if not self.cpu_offload:
@@ -85,3 +87,9 @@ class BasePipeline(torch.nn.Module):
model.to(self.device)
# fresh the cuda cache
torch.cuda.empty_cache()
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
generator = None if seed is None else torch.Generator(device).manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
return noise