support flux highresfix

This commit is contained in:
Artiprocher
2024-08-19 13:35:27 +08:00
parent 80aa4d8e19
commit 778a2d8f84
7 changed files with 111 additions and 25 deletions

View File

@@ -1,6 +1,7 @@
import torch
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm
from einops import rearrange
from .tiler import TileWorker
@@ -306,9 +307,62 @@ class FluxDiT(torch.nn.Module):
def unpatchify(self, hidden_states, height, width):
    """Fold a token sequence back into a 2D feature map.

    Each token's feature axis is split into ``(C, 2, 2)`` and the token axis
    into a ``(height // 2, width // 2)`` grid, giving an output laid out as
    ``(B, C, 2 * (height // 2), 2 * (width // 2))``.

    Args:
        hidden_states: Input laid out as ``B (H W) (C P Q)`` with ``P = Q = 2``.
        height: Target spatial height; the token grid has ``height // 2`` rows.
        width: Target spatial width; the token grid has ``width // 2`` columns.

    Returns:
        The rearranged ``B C (H P) (W Q)`` result.
    """
    grid_rows = height // 2
    grid_cols = width // 2
    return rearrange(
        hidden_states,
        "B (H W) (C P Q) -> B C (H P) (W Q)",
        P=2,
        Q=2,
        H=grid_rows,
        W=grid_cols,
    )
def prepare_image_ids(self, latents):
    """Build per-token position ids for the 2x2-downsampled latent grid.

    Args:
        latents: 4D tensor whose shape unpacks as
            ``(batch_size, channels, height, width)``.

    Returns:
        Tensor of shape ``(batch_size, (height // 2) * (width // 2), 3)`` on
        the same device and with the same dtype as ``latents``. Channel 0 is
        all zeros, channel 1 holds the row index and channel 2 the column
        index of each grid cell, flattened in row-major order.
    """
    batch_size, _, height, width = latents.shape
    rows, cols = height // 2, width // 2

    # Fill the grid in default float32 first; the cast to the latent dtype
    # happens once at the end.
    ids = torch.zeros(rows, cols, 3)
    ids[..., 1] = torch.arange(rows)[:, None]
    ids[..., 2] = torch.arange(cols)[None, :]

    # Flatten the grid to a token axis, then copy it for every batch item.
    ids = ids.reshape(1, rows * cols, 3).repeat(batch_size, 1, 1)
    return ids.to(device=latents.device, dtype=latents.dtype)
def tiled_forward(
    self,
    hidden_states,
    timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
    tile_size=128, tile_stride=64,
    **kwargs
):
    """Run ``forward`` over tiles of ``hidden_states`` via ``TileWorker``.

    Args:
        hidden_states: Input tensor to be tiled; its device and dtype are
            passed to the worker as ``tile_device`` / ``tile_dtype``.
        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids:
            Conditioning inputs, passed unchanged to every per-tile
            ``forward`` call.
        tile_size: Tile size handed to ``TileWorker.tiled_forward``.
        tile_stride: Tile stride handed to ``TileWorker.tiled_forward``.
        **kwargs: Accepted for signature compatibility; not forwarded.

    Returns:
        Whatever ``TileWorker.tiled_forward`` returns for the full input.
    """
    # Due to the global positional embedding, we cannot implement layer-wise tiled forward.
    def run_on_tile(tile):
        # image_ids=None makes forward() rebuild the position ids per tile.
        return self.forward(
            tile, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
            image_ids=None,
        )

    worker = TileWorker()
    return worker.tiled_forward(
        run_on_tile,
        hidden_states,
        tile_size,
        tile_stride,
        tile_device=hidden_states.device,
        tile_dtype=hidden_states.dtype,
    )
def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids, **kwargs):
def forward(
self,
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
tiled=False, tile_size=128, tile_stride=64,
**kwargs
):
if tiled:
return self.tiled_forward(
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
tile_size=tile_size, tile_stride=tile_stride,
**kwargs
)
if image_ids is None:
image_ids = self.prepare_image_ids(hidden_states)
conditioning = self.time_embedder(timestep, hidden_states.dtype)\
+ self.guidance_embedder(guidance, hidden_states.dtype)\
+ self.pooled_text_embedder(pooled_prompt_emb)

View File

@@ -64,20 +64,8 @@ class FluxImagePipeline(BasePipeline):
def prepare_extra_input(self, latents=None, guidance=0.0):
    """Assemble the extra model inputs that depend on the latent shape.

    The position ids are produced by ``self.dit.prepare_image_ids`` so the
    pipeline and the model share a single implementation; the duplicate
    inline computation that used to precede this call was dead code (its
    result was overwritten) and has been removed.

    Args:
        latents: Latent tensor; its first dimension is used as the batch
            size. Must be provided despite the ``None`` default, because
            its shape, device and dtype are read.
        guidance: Scalar guidance value, repeated once per batch item.

    Returns:
        Dict with ``"image_ids"`` (result of ``self.dit.prepare_image_ids``)
        and ``"guidance"`` (1D tensor of length ``batch_size`` on the same
        device and with the same dtype as ``latents``).
    """
    batch_size = latents.shape[0]
    latent_image_ids = self.dit.prepare_image_ids(latents)
    guidance = torch.Tensor([guidance] * batch_size).to(device=latents.device, dtype=latents.dtype)
    return {"image_ids": latent_image_ids, "guidance": guidance}
@@ -88,7 +76,9 @@ class FluxImagePipeline(BasePipeline):
local_prompts=[],
masks=[],
mask_scales=[],
cfg_scale=0.0,
negative_prompt="",
cfg_scale=1.0,
embedded_guidance=0.0,
input_image=None,
denoising_strength=1.0,
height=1024,
@@ -116,23 +106,32 @@ class FluxImagePipeline(BasePipeline):
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
prompt_emb = self.encode_prompt(prompt, positive=True)
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=cfg_scale)
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Inference (FLUX doesn't support classifier-free guidance)
inference_callback = lambda prompt_emb: self.dit(
latents, timestep=timestep, **prompt_emb, **tiler_kwargs, **extra_input
# Classifier-free guidance
inference_callback = lambda prompt_emb_posi: self.dit(
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
)
noise_pred = self.control_noise_via_local_prompts(prompt_emb, prompt_emb_locals, masks, mask_scales, inference_callback)
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# DDIM
# Iterate
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# UI

View File

@@ -2,10 +2,20 @@
Image synthesis is the core feature of DiffSynth Studio. It can generate images at very high resolutions.
### Example: FLUX
Example script: [`flux_text_to_image.py`](./flux_text_to_image.py)
|1024*1024 (original)|1024*1024 (classifier-free guidance)|2048*2048 (highres-fix)|
|-|-|-|
|![image_1024](https://github.com/user-attachments/assets/d8e66872-8739-43e4-8c2b-eda9daba0450)|![image_1024_cfg](https://github.com/user-attachments/assets/1073c70d-018f-47e4-9342-bc580b4c7c59)|![image_2048_highres](https://github.com/user-attachments/assets/8719c1a8-b341-48c1-a085-364c3a7d25f0)|
### Example: Stable Diffusion
Example script: [`sd_text_to_image.py`](./sd_text_to_image.py)
LoRA Training: [`../train/stable_diffusion/`](../train/stable_diffusion/)
|512*512|1024*1024|2048*2048|4096*4096|
|-|-|-|-|
|![512](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/55f679e9-7445-4605-9315-302e93d11370)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/9087a73c-9164-4c58-b2a0-effc694143fb)|![4096](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edee9e71-fc39-4d1c-9ca9-fa52002c67ac)|
@@ -14,6 +24,8 @@ Example script: [`sd_text_to_image.py`](./sd_text_to_image.py)
Example script: [`sdxl_text_to_image.py`](./sdxl_text_to_image.py)
LoRA Training: [`../train/stable_diffusion_xl/`](../train/stable_diffusion_xl/)
|1024*1024|2048*2048|
|-|-|
|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)|

View File

@@ -12,9 +12,30 @@ model_manager.load_models([
])
pipe = FluxImagePipeline.from_model_manager(model_manager)
prompt = "A captivating fantasy magic woman portrait set in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw,"
# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
torch.manual_seed(6)
image = pipe(
"A captivating fantasy magic woman portrait set in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin.",
num_inference_steps=30
prompt=prompt,
num_inference_steps=30,
)
image.save("image_1024.jpg")
# Enable classifier-free guidance
torch.manual_seed(6)
image = pipe(
prompt=prompt, negative_prompt=negative_prompt,
num_inference_steps=30, cfg_scale=2.0
)
image.save("image_1024_cfg.jpg")
# Highres-fix
torch.manual_seed(7)
image = pipe(
prompt=prompt,
num_inference_steps=30,
input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
)
image.save("image_2048_highres.jpg")

BIN
image_1024.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

BIN
image_1024_cfg.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

BIN
image_2048_highres.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 357 KiB