mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support flux highresfix
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm
|
||||
from einops import rearrange
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
|
||||
@@ -306,9 +307,62 @@ class FluxDiT(torch.nn.Module):
|
||||
def unpatchify(self, hidden_states, height, width):
    """Fold a sequence of 2x2 patch tokens back into a 2-D feature map.

    The token axis is split into an (H, W) grid with ``H = height // 2`` and
    ``W = width // 2``; the feature axis is split into (C, P, Q) with
    ``P = Q = 2``. The result is laid out as ``B C (H P) (W Q)``.
    """
    grid_rows, grid_cols = height // 2, width // 2
    return rearrange(
        hidden_states,
        "B (H W) (C P Q) -> B C (H P) (W Q)",
        P=2, Q=2, H=grid_rows, W=grid_cols,
    )
|
||||
|
||||
|
||||
def prepare_image_ids(self, latents):
    """Build the per-token position ids for a batch of latents.

    Args:
        latents: Tensor whose shape unpacks as (batch, channels, height,
            width). Only its shape, device and dtype are read.

    Returns:
        Tensor of shape (batch, (height // 2) * (width // 2), 3) on
        ``latents.device`` with ``latents.dtype``. Channel 0 is zero,
        channel 1 holds the row index and channel 2 the column index of each
        cell of the (height // 2, width // 2) grid, flattened row-major.
    """
    batch_size, _, height, width = latents.shape
    rows, cols = height // 2, width // 2
    device = latents.device

    # Allocate on the target device right away: forward() calls this whenever
    # image_ids is None, so a CPU build followed by a host-to-device copy
    # would be repeated on every such call.
    grid_ids = torch.zeros(rows, cols, 3, device=device)
    grid_ids[..., 1] = torch.arange(rows, device=device)[:, None]
    grid_ids[..., 2] = torch.arange(cols, device=device)[None, :]

    # Flatten the grid to a token sequence, then materialize one copy per
    # batch element (repeat, not expand, so the result owns its memory).
    image_ids = grid_ids.reshape(1, rows * cols, 3).repeat(batch_size, 1, 1)

    # Cast last so the indices are generated in the default float dtype and
    # only then rounded to the latents' dtype.
    return image_ids.to(dtype=latents.dtype)
|
||||
|
||||
|
||||
def tiled_forward(
    self,
    hidden_states,
    timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
    tile_size=128, tile_stride=64,
    **kwargs
):
    """Run ``forward`` tile by tile through ``TileWorker``.

    Every tile is passed to ``self.forward`` together with the shared
    conditioning inputs and ``image_ids=None``. The tiles are processed on
    the device and in the dtype of ``hidden_states``. Extra ``**kwargs`` are
    accepted but not forwarded.
    """
    # Due to the global positional embedding, we cannot implement layer-wise tiled forward.
    def run_on_tile(tile):
        return self.forward(tile, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None)

    worker = TileWorker()
    return worker.tiled_forward(
        run_on_tile,
        hidden_states,
        tile_size,
        tile_stride,
        tile_device=hidden_states.device,
        tile_dtype=hidden_states.dtype,
    )
|
||||
|
||||
|
||||
def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids, **kwargs):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
|
||||
tiled=False, tile_size=128, tile_stride=64,
|
||||
**kwargs
|
||||
):
|
||||
if tiled:
|
||||
return self.tiled_forward(
|
||||
hidden_states,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
|
||||
tile_size=tile_size, tile_stride=tile_stride,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if image_ids is None:
|
||||
image_ids = self.prepare_image_ids(hidden_states)
|
||||
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype)\
|
||||
+ self.guidance_embedder(guidance, hidden_states.dtype)\
|
||||
+ self.pooled_text_embedder(pooled_prompt_emb)
|
||||
|
||||
@@ -64,20 +64,8 @@ class FluxImagePipeline(BasePipeline):
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None, guidance=0.0):
    """Assemble the extra keyword inputs passed to the DiT alongside the latents.

    Args:
        latents: Latent tensor; its batch size, device and dtype are used.
            Must be provided despite the ``None`` default, since its
            ``shape`` is read unconditionally.
        guidance: Scalar guidance value, repeated once per batch element.

    Returns:
        Dict with ``"image_ids"`` (from ``self.dit.prepare_image_ids``) and
        ``"guidance"`` (1-D tensor of length batch on ``latents.device`` with
        ``latents.dtype``).
    """
    batch_size = latents.shape[0]

    # Position ids are produced by the DiT itself; the inline copy of that
    # computation that used to live here was dead code whose result was
    # overwritten before use.
    latent_image_ids = self.dit.prepare_image_ids(latents)

    # Create in the default float dtype first, then cast, matching the
    # rounding of the previous torch.Tensor([...]).to(...) construction.
    guidance = torch.full((batch_size,), float(guidance), device=latents.device).to(dtype=latents.dtype)
    return {"image_ids": latent_image_ids, "guidance": guidance}
|
||||
|
||||
|
||||
@@ -88,7 +76,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
local_prompts=[],
|
||||
masks=[],
|
||||
mask_scales=[],
|
||||
cfg_scale=0.0,
|
||||
negative_prompt="",
|
||||
cfg_scale=1.0,
|
||||
embedded_guidance=0.0,
|
||||
input_image=None,
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
@@ -116,23 +106,32 @@ class FluxImagePipeline(BasePipeline):
|
||||
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb = self.encode_prompt(prompt, positive=True)
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
|
||||
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents, guidance=cfg_scale)
|
||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Inference (FLUX doesn't support classifier-free guidance)
|
||||
inference_callback = lambda prompt_emb: self.dit(
|
||||
latents, timestep=timestep, **prompt_emb, **tiler_kwargs, **extra_input
|
||||
# Classifier-free guidance
|
||||
inference_callback = lambda prompt_emb_posi: self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
|
||||
)
|
||||
noise_pred = self.control_noise_via_local_prompts(prompt_emb, prompt_emb_locals, masks, mask_scales, inference_callback)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# DDIM
|
||||
# Iterate
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# UI
|
||||
|
||||
@@ -2,10 +2,20 @@
|
||||
|
||||
Image synthesis is the core feature of DiffSynth Studio. It can generate images at very high resolutions.
|
||||
|
||||
### Example: FLUX
|
||||
|
||||
Example script: [`flux_text_to_image.py`](./flux_text_to_image.py)
|
||||
|
||||
|1024*1024 (original)|1024*1024 (classifier-free guidance)|2048*2048 (highres-fix)|
|
||||
|-|-|-|
|
||||
||||
|
||||
|
||||
### Example: Stable Diffusion
|
||||
|
||||
Example script: [`sd_text_to_image.py`](./sd_text_to_image.py)
|
||||
|
||||
LoRA Training: [`../train/stable_diffusion/`](../train/stable_diffusion/)
|
||||
|
||||
|512*512|1024*1024|2048*2048|4096*4096|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
@@ -14,6 +24,8 @@ Example script: [`sd_text_to_image.py`](./sd_text_to_image.py)
|
||||
|
||||
Example script: [`sdxl_text_to_image.py`](./sdxl_text_to_image.py)
|
||||
|
||||
LoRA Training: [`../train/stable_diffusion_xl/`](../train/stable_diffusion_xl/)
|
||||
|
||||
|1024*1024|2048*2048|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
@@ -12,9 +12,30 @@ model_manager.load_models([
|
||||
])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)

prompt = "A captivating fantasy magic woman portrait set in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw,"

# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
torch.manual_seed(6)
image = pipe(
    prompt=prompt,
    num_inference_steps=30,
)
image.save("image_1024.jpg")

# Enable classifier-free guidance
torch.manual_seed(6)
image = pipe(
    prompt=prompt, negative_prompt=negative_prompt,
    num_inference_steps=30, cfg_scale=2.0
)
image.save("image_1024_cfg.jpg")

# Highres-fix
# `image` here is the classifier-free-guidance result from the previous step;
# it is upscaled and partially re-denoised with tiled inference.
torch.manual_seed(7)
image = pipe(
    prompt=prompt,
    num_inference_steps=30,
    input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
)
image.save("image_2048_highres.jpg")
|
||||
|
||||
BIN
image_1024.jpg
Normal file
BIN
image_1024.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 102 KiB |
BIN
image_1024_cfg.jpg
Normal file
BIN
image_1024_cfg.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 112 KiB |
BIN
image_2048_highres.jpg
Normal file
BIN
image_2048_highres.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 357 KiB |
Reference in New Issue
Block a user