update flux pipeline

This commit is contained in:
Artiprocher
2024-10-10 17:05:04 +08:00
parent 41ea2f811a
commit fa0fa95bb6
6 changed files with 32 additions and 22 deletions

View File

@@ -65,9 +65,11 @@ class BasePipeline(torch.nn.Module):
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
return prompt, local_prompts, masks, mask_scales
def enable_cpu_offload(self):
    """Turn on CPU offloading by setting the ``cpu_offload`` flag to True.

    ``load_models_to_device`` checks this flag and only moves models
    between devices when it is set.
    """
    self.cpu_offload = True
def load_models_to_device(self, loadmodel_names=[]):
# only load models to device if cpu_offload is enabled
if not self.cpu_offload:
@@ -85,3 +87,9 @@ class BasePipeline(torch.nn.Module):
model.to(self.device)
# release cached CUDA memory that is no longer in use
torch.cuda.empty_cache()
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
    """Sample a standard-normal noise tensor of the requested shape.

    Args:
        shape: Size of the tensor to draw, forwarded to ``torch.randn``.
        seed: Optional seed. When given, a dedicated ``torch.Generator`` is
            created on ``device`` and seeded with it; when ``None``, no
            generator is passed and torch's default generator is used.
        device: Device on which the generator (if any) and the tensor are
            created.
        dtype: Data type of the returned tensor.

    Returns:
        The tensor produced by ``torch.randn``.
    """
    if seed is None:
        rng = None
    else:
        # Seed a private generator so the global RNG state is left untouched.
        rng = torch.Generator(device)
        rng.manual_seed(seed)
    return torch.randn(shape, generator=rng, device=device, dtype=dtype)

View File

@@ -58,7 +58,7 @@ class FluxImagePipeline(BasePipeline):
return image
def encode_prompt(self, prompt, positive=True, t5_sequence_length=256):
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
)
@@ -80,7 +80,7 @@ class FluxImagePipeline(BasePipeline):
mask_scales= None,
negative_prompt="",
cfg_scale=1.0,
embedded_guidance=1.0,
embedded_guidance=3.5,
input_image=None,
denoising_strength=1.0,
height=1024,
@@ -90,6 +90,7 @@ class FluxImagePipeline(BasePipeline):
tiled=False,
tile_size=128,
tile_stride=64,
seed=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
@@ -104,10 +105,10 @@ class FluxImagePipeline(BasePipeline):
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# Extend prompt
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])