support flux highresfix

This commit is contained in:
Artiprocher
2024-08-19 13:35:27 +08:00
parent 80aa4d8e19
commit 778a2d8f84
7 changed files with 111 additions and 25 deletions

View File

@@ -64,20 +64,8 @@ class FluxImagePipeline(BasePipeline):
def prepare_extra_input(self, latents=None, guidance=0.0):
    """Build the extra keyword inputs passed to the DiT alongside the latents.

    Args:
        latents: Latent tensor; its leading dimension is used as the batch
            size, and its device and dtype are applied to the guidance
            tensor. Despite the ``None`` default it must be provided,
            because its ``shape``, ``device`` and ``dtype`` are read.
        guidance: Scalar guidance value, repeated once per batch element.

    Returns:
        A dict with two entries: ``"image_ids"``, the value returned by
        ``self.dit.prepare_image_ids(latents)``, and ``"guidance"``, a 1-D
        tensor of length ``latents.shape[0]`` filled with ``guidance``.
    """
    # Image ids come from the DiT helper. Building them inline here as well
    # would be dead work, since only the helper's result is returned.
    latent_image_ids = self.dit.prepare_image_ids(latents)
    # One guidance value per batch element, on the same device/dtype as latents.
    guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
    return {"image_ids": latent_image_ids, "guidance": guidance}
@@ -88,7 +76,9 @@ class FluxImagePipeline(BasePipeline):
local_prompts=[],
masks=[],
mask_scales=[],
cfg_scale=0.0,
negative_prompt="",
cfg_scale=1.0,
embedded_guidance=0.0,
input_image=None,
denoising_strength=1.0,
height=1024,
@@ -116,23 +106,32 @@ class FluxImagePipeline(BasePipeline):
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
prompt_emb = self.encode_prompt(prompt, positive=True)
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=cfg_scale)
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Inference (FLUX doesn't support classifier-free guidance)
inference_callback = lambda prompt_emb: self.dit(
latents, timestep=timestep, **prompt_emb, **tiler_kwargs, **extra_input
# Classifier-free guidance
inference_callback = lambda prompt_emb_posi: self.dit(
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
)
noise_pred = self.control_noise_via_local_prompts(prompt_emb, prompt_emb_locals, masks, mask_scales, inference_callback)
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# DDIM
# Iterate
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# UI