Merge pull request #164 from modelscope/Artiprocher-dev

FLUX highres-fix
This commit is contained in:
Zhongjie Duan
2024-08-20 13:40:23 +08:00
committed by GitHub
6 changed files with 130 additions and 27 deletions

View File

@@ -1,6 +1,7 @@
import torch import torch
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm from .sd3_dit import TimestepEmbeddings, AdaLayerNorm
from einops import rearrange from einops import rearrange
from .tiler import TileWorker
@@ -308,7 +309,60 @@ class FluxDiT(torch.nn.Module):
return hidden_states return hidden_states
def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids, **kwargs): def prepare_image_ids(self, latents):
batch_size, _, height, width = latents.shape
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
return latent_image_ids
def tiled_forward(
self,
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
tile_size=128, tile_stride=64,
**kwargs
):
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None),
hidden_states,
tile_size,
tile_stride,
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
return hidden_states
def forward(
self,
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
tiled=False, tile_size=128, tile_stride=64,
**kwargs
):
if tiled:
return self.tiled_forward(
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
tile_size=tile_size, tile_stride=tile_stride,
**kwargs
)
if image_ids is None:
image_ids = self.prepare_image_ids(hidden_states)
conditioning = self.time_embedder(timestep, hidden_states.dtype)\ conditioning = self.time_embedder(timestep, hidden_states.dtype)\
+ self.guidance_embedder(guidance, hidden_states.dtype)\ + self.guidance_embedder(guidance, hidden_states.dtype)\
+ self.pooled_text_embedder(pooled_prompt_emb) + self.pooled_text_embedder(pooled_prompt_emb)

View File

@@ -64,20 +64,8 @@ class FluxImagePipeline(BasePipeline):
def prepare_extra_input(self, latents=None, guidance=0.0): def prepare_extra_input(self, latents=None, guidance=0.0):
batch_size, _, height, width = latents.shape latent_image_ids = self.dit.prepare_image_ids(latents)
latent_image_ids = torch.zeros(height // 2, width // 2, 3) guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
guidance = torch.Tensor([guidance] * batch_size).to(device=latents.device, dtype=latents.dtype)
return {"image_ids": latent_image_ids, "guidance": guidance} return {"image_ids": latent_image_ids, "guidance": guidance}
@@ -88,7 +76,9 @@ class FluxImagePipeline(BasePipeline):
local_prompts=[], local_prompts=[],
masks=[], masks=[],
mask_scales=[], mask_scales=[],
cfg_scale=0.0, negative_prompt="",
cfg_scale=1.0,
embedded_guidance=0.0,
input_image=None, input_image=None,
denoising_strength=1.0, denoising_strength=1.0,
height=1024, height=1024,
@@ -116,23 +106,32 @@ class FluxImagePipeline(BasePipeline):
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts # Encode prompts
prompt_emb = self.encode_prompt(prompt, positive=True) prompt_emb_posi = self.encode_prompt(prompt, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts] prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
# Extra input # Extra input
extra_input = self.prepare_extra_input(latents, guidance=cfg_scale) extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# Denoise # Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device) timestep = timestep.unsqueeze(0).to(self.device)
# Inference (FLUX doesn't support classifier-free guidance) # Classifier-free guidance
inference_callback = lambda prompt_emb: self.dit( inference_callback = lambda prompt_emb_posi: self.dit(
latents, timestep=timestep, **prompt_emb, **tiler_kwargs, **extra_input latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
) )
noise_pred = self.control_noise_via_local_prompts(prompt_emb, prompt_emb_locals, masks, mask_scales, inference_callback) noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# DDIM # Iterate
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# UI # UI

View File

@@ -2,10 +2,20 @@
Image synthesis is the base feature of DiffSynth Studio. We can generate images with very high resolution. Image synthesis is the base feature of DiffSynth Studio. We can generate images with very high resolution.
### Example: FLUX
Example script: [`flux_text_to_image.py`](./flux_text_to_image.py)
|1024*1024 (original)|1024*1024 (classifier-free guidance)|2048*2048 (highres-fix)|
|-|-|-|
|![image_1024](https://github.com/user-attachments/assets/d8e66872-8739-43e4-8c2b-eda9daba0450)|![image_1024_cfg](https://github.com/user-attachments/assets/1073c70d-018f-47e4-9342-bc580b4c7c59)|![image_2048_highres](https://github.com/user-attachments/assets/8719c1a8-b341-48c1-a085-364c3a7d25f0)|
### Example: Stable Diffusion ### Example: Stable Diffusion
Example script: [`sd_text_to_image.py`](./sd_text_to_image.py) Example script: [`sd_text_to_image.py`](./sd_text_to_image.py)
LoRA Training: [`../train/stable_diffusion/`](../train/stable_diffusion/)
|512*512|1024*1024|2048*2048|4096*4096| |512*512|1024*1024|2048*2048|4096*4096|
|-|-|-|-| |-|-|-|-|
|![512](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/55f679e9-7445-4605-9315-302e93d11370)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/9087a73c-9164-4c58-b2a0-effc694143fb)|![4096](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edee9e71-fc39-4d1c-9ca9-fa52002c67ac)| |![512](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/55f679e9-7445-4605-9315-302e93d11370)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/9087a73c-9164-4c58-b2a0-effc694143fb)|![4096](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edee9e71-fc39-4d1c-9ca9-fa52002c67ac)|
@@ -14,6 +24,8 @@ Example script: [`sd_text_to_image.py`](./sd_text_to_image.py)
Example script: [`sdxl_text_to_image.py`](./sdxl_text_to_image.py) Example script: [`sdxl_text_to_image.py`](./sdxl_text_to_image.py)
LoRA Training: [`../train/stable_diffusion_xl/`](../train/stable_diffusion_xl/)
|1024*1024|2048*2048| |1024*1024|2048*2048|
|-|-| |-|-|
|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)| |![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)|

View File

@@ -12,9 +12,30 @@ model_manager.load_models([
]) ])
pipe = FluxImagePipeline.from_model_manager(model_manager) pipe = FluxImagePipeline.from_model_manager(model_manager)
prompt = "A captivating fantasy magic woman portrait set in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw,"
# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
torch.manual_seed(6) torch.manual_seed(6)
image = pipe( image = pipe(
"A captivating fantasy magic woman portrait set in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin.", prompt=prompt,
num_inference_steps=30 num_inference_steps=30,
) )
image.save("image_1024.jpg") image.save("image_1024.jpg")
# Enable classifier-free guidance
torch.manual_seed(6)
image = pipe(
prompt=prompt, negative_prompt=negative_prompt,
num_inference_steps=30, cfg_scale=2.0
)
image.save("image_1024_cfg.jpg")
# Highres-fix
torch.manual_seed(7)
image = pipe(
prompt=prompt,
num_inference_steps=30,
input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
)
image.save("image_2048_highres.jpg")

View File

@@ -5,7 +5,7 @@ import streamlit as st
st.set_page_config(layout="wide") st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas from streamlit_drawable_canvas import st_canvas
from diffsynth.models import ModelManager from diffsynth.models import ModelManager
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
from diffsynth.data.video import crop_and_resize from diffsynth.data.video import crop_and_resize
@@ -49,13 +49,20 @@ config = {
"width": 1024, "width": 1024,
} }
}, },
"FLUX": {
"model_folder": "models/FLUX",
"pipeline_class": FluxImagePipeline,
"fixed_parameters": {
"cfg_scale": 1.0,
}
}
} }
def load_model_list(model_type): def load_model_list(model_type):
folder = config[model_type]["model_folder"] folder = config[model_type]["model_folder"]
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")] file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
if model_type in ["HunyuanDiT", "Kolors"]: if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))] file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
file_list = sorted(file_list) file_list = sorted(file_list)
return file_list return file_list
@@ -85,6 +92,16 @@ def load_model(model_type, model_path):
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"), os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"), os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
]) ])
elif model_type == "FLUX":
model_manager.torch_dtype = torch.bfloat16
file_list = [
os.path.join(model_path, "text_encoder/model.safetensors"),
os.path.join(model_path, "text_encoder_2"),
]
for file_name in os.listdir(model_path):
if file_name.endswith(".safetensors"):
file_list.append(os.path.join(model_path, file_name))
model_manager.load_models(file_list)
else: else:
model_manager.load_model(model_path) model_manager.load_model(model_path)
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager) pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)