mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
support flux highresfix
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm
|
||||
from einops import rearrange
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
|
||||
@@ -306,9 +307,62 @@ class FluxDiT(torch.nn.Module):
|
||||
def unpatchify(self, hidden_states, height, width):
|
||||
hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def prepare_image_ids(self, latents):
|
||||
batch_size, _, height, width = latents.shape
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
return latent_image_ids
|
||||
|
||||
|
||||
def tiled_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
|
||||
tile_size=128, tile_stride=64,
|
||||
**kwargs
|
||||
):
|
||||
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None),
|
||||
hidden_states,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids, **kwargs):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
|
||||
tiled=False, tile_size=128, tile_stride=64,
|
||||
**kwargs
|
||||
):
|
||||
if tiled:
|
||||
return self.tiled_forward(
|
||||
hidden_states,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
|
||||
tile_size=tile_size, tile_stride=tile_stride,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if image_ids is None:
|
||||
image_ids = self.prepare_image_ids(hidden_states)
|
||||
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype)\
|
||||
+ self.guidance_embedder(guidance, hidden_states.dtype)\
|
||||
+ self.pooled_text_embedder(pooled_prompt_emb)
|
||||
|
||||
@@ -64,20 +64,8 @@ class FluxImagePipeline(BasePipeline):
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None, guidance=0.0):
|
||||
batch_size, _, height, width = latents.shape
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
guidance = torch.Tensor([guidance] * batch_size).to(device=latents.device, dtype=latents.dtype)
|
||||
latent_image_ids = self.dit.prepare_image_ids(latents)
|
||||
guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
|
||||
return {"image_ids": latent_image_ids, "guidance": guidance}
|
||||
|
||||
|
||||
@@ -88,7 +76,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
local_prompts=[],
|
||||
masks=[],
|
||||
mask_scales=[],
|
||||
cfg_scale=0.0,
|
||||
negative_prompt="",
|
||||
cfg_scale=1.0,
|
||||
embedded_guidance=0.0,
|
||||
input_image=None,
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
@@ -116,23 +106,32 @@ class FluxImagePipeline(BasePipeline):
|
||||
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb = self.encode_prompt(prompt, positive=True)
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
|
||||
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents, guidance=cfg_scale)
|
||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Inference (FLUX doesn't support classifier-free guidance)
|
||||
inference_callback = lambda prompt_emb: self.dit(
|
||||
latents, timestep=timestep, **prompt_emb, **tiler_kwargs, **extra_input
|
||||
# Classifier-free guidance
|
||||
inference_callback = lambda prompt_emb_posi: self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
|
||||
)
|
||||
noise_pred = self.control_noise_via_local_prompts(prompt_emb, prompt_emb_locals, masks, mask_scales, inference_callback)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# DDIM
|
||||
# Iterate
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# UI
|
||||
|
||||
Reference in New Issue
Block a user