diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 487810a..6cca984 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -1,6 +1,7 @@ import torch from .sd3_dit import TimestepEmbeddings, AdaLayerNorm from einops import rearrange +from .tiler import TileWorker @@ -306,9 +307,62 @@ class FluxDiT(torch.nn.Module): def unpatchify(self, hidden_states, height, width): hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2) return hidden_states + + + def prepare_image_ids(self, latents): + batch_size, _, height, width = latents.shape + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype) + + return latent_image_ids + + + def tiled_forward( + self, + hidden_states, + timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, + tile_size=128, tile_stride=64, + **kwargs + ): + # Due to the global positional embedding, we cannot implement layer-wise tiled forward. + hidden_states = TileWorker().tiled_forward( + lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None), + hidden_states, + tile_size, + tile_stride, + tile_device=hidden_states.device, + tile_dtype=hidden_states.dtype + ) + return hidden_states - def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids, **kwargs): + def forward( + self, + hidden_states, + timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, + tiled=False, tile_size=128, tile_stride=64, + **kwargs + ): + if tiled: + return self.tiled_forward( + hidden_states, + timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, + tile_size=tile_size, tile_stride=tile_stride, + **kwargs + ) + + if image_ids is None: + image_ids = self.prepare_image_ids(hidden_states) + conditioning = self.time_embedder(timestep, hidden_states.dtype)\ + self.guidance_embedder(guidance, hidden_states.dtype)\ + self.pooled_text_embedder(pooled_prompt_emb) diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 0cc4a90..74de285 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -64,20 +64,8 @@ class FluxImagePipeline(BasePipeline): def prepare_extra_input(self, latents=None, guidance=0.0): - batch_size, _, height, width = latents.shape - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) - latent_image_ids = latent_image_ids.reshape( - batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype) - - guidance = torch.Tensor([guidance] * batch_size).to(device=latents.device, dtype=latents.dtype) + latent_image_ids = self.dit.prepare_image_ids(latents) + guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype) return {"image_ids": latent_image_ids, "guidance": guidance} @@ -88,7 +76,9 @@ class FluxImagePipeline(BasePipeline): local_prompts=[], masks=[], mask_scales=[], - cfg_scale=0.0, + negative_prompt="", + cfg_scale=1.0, + embedded_guidance=0.0, input_image=None, denoising_strength=1.0, height=1024, @@ -116,23 +106,32 @@ class FluxImagePipeline(BasePipeline): latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) # Encode prompts - prompt_emb = self.encode_prompt(prompt, positive=True) + prompt_emb_posi = self.encode_prompt(prompt, positive=True) + if cfg_scale != 1.0: + prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False) prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts] # Extra input - extra_input = self.prepare_extra_input(latents, guidance=cfg_scale) + extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) # Denoise for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = timestep.unsqueeze(0).to(self.device) - # Inference (FLUX doesn't support classifier-free guidance) - inference_callback = lambda prompt_emb: self.dit( - latents, timestep=timestep, **prompt_emb, **tiler_kwargs, **extra_input + # Classifier-free guidance + inference_callback = lambda prompt_emb_posi: self.dit( + latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input ) - noise_pred = self.control_noise_via_local_prompts(prompt_emb, prompt_emb_locals, masks, mask_scales, inference_callback) + noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback) + if cfg_scale != 1.0: + noise_pred_nega = self.dit( + latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input + ) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi - # DDIM + # Iterate latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) # UI diff --git a/examples/image_synthesis/README.md b/examples/image_synthesis/README.md index 5fcfa9d..2c751d0 100644 --- a/examples/image_synthesis/README.md +++ b/examples/image_synthesis/README.md @@ -2,10 +2,20 @@ Image synthesis is the base feature of DiffSynth Studio. We can generate images with very high resolution. +### Example: FLUX + +Example script: [`flux_text_to_image.py`](./flux_text_to_image.py) + +|1024*1024 (original)|1024*1024 (classifier-free guidance)|2048*2048 (highres-fix)| +|-|-|-| +|![image_1024](https://github.com/user-attachments/assets/d8e66872-8739-43e4-8c2b-eda9daba0450)|![image_1024_cfg](https://github.com/user-attachments/assets/1073c70d-018f-47e4-9342-bc580b4c7c59)|![image_2048_highres](https://github.com/user-attachments/assets/8719c1a8-b341-48c1-a085-364c3a7d25f0)| + ### Example: Stable Diffusion Example script: [`sd_text_to_image.py`](./sd_text_to_image.py) +LoRA Training: [`../train/stable_diffusion/`](../train/stable_diffusion/) + |512*512|1024*1024|2048*2048|4096*4096| |-|-|-|-| |![512](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/55f679e9-7445-4605-9315-302e93d11370)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/9087a73c-9164-4c58-b2a0-effc694143fb)|![4096](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edee9e71-fc39-4d1c-9ca9-fa52002c67ac)| @@ -14,6 +24,8 @@ Example script: [`sd_text_to_image.py`](./sd_text_to_image.py) Example script: [`sdxl_text_to_image.py`](./sdxl_text_to_image.py) +LoRA Training: [`../train/stable_diffusion_xl/`](../train/stable_diffusion_xl/) + |1024*1024|2048*2048| |-|-| |![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)| diff --git a/examples/image_synthesis/flux_text_to_image.py b/examples/image_synthesis/flux_text_to_image.py index 7cbfdc9..775c684 100644 --- a/examples/image_synthesis/flux_text_to_image.py +++ b/examples/image_synthesis/flux_text_to_image.py @@ -12,9 +12,30 @@ model_manager.load_models([ ]) pipe = FluxImagePipeline.from_model_manager(model_manager) +prompt = "A captivating fantasy magic woman portrait set in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin." +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw," + +# Disable classifier-free guidance (consistent with the original implementation of FLUX.1) torch.manual_seed(6) image = pipe( - "A captivating fantasy magic woman portrait set in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin.", - num_inference_steps=30 + prompt=prompt, + num_inference_steps=30, ) image.save("image_1024.jpg") + +# Enable classifier-free guidance +torch.manual_seed(6) +image = pipe( + prompt=prompt, negative_prompt=negative_prompt, + num_inference_steps=30, cfg_scale=2.0 +) +image.save("image_1024_cfg.jpg") + +# Highres-fix +torch.manual_seed(7) +image = pipe( + prompt=prompt, + num_inference_steps=30, + input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True +) +image.save("image_2048_highres.jpg") diff --git a/models/FLUX/Put Stable Diffusion checkpoints here.txt b/models/FLUX/Put Stable Diffusion checkpoints here.txt new file mode 100644 index 0000000..e69de29 diff --git a/pages/1_Image_Creator.py b/pages/1_Image_Creator.py index 2d13782..3b8ad45 100644 --- a/pages/1_Image_Creator.py +++ b/pages/1_Image_Creator.py @@ -5,7 +5,7 @@ import streamlit as st st.set_page_config(layout="wide") from streamlit_drawable_canvas import st_canvas from diffsynth.models import ModelManager -from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline +from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline from diffsynth.data.video import crop_and_resize @@ -49,13 +49,20 @@ config = { "width": 1024, } }, + "FLUX": { + "model_folder": "models/FLUX", + "pipeline_class": FluxImagePipeline, + "fixed_parameters": { + "cfg_scale": 1.0, + } + } } def load_model_list(model_type): folder = config[model_type]["model_folder"] file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")] - if model_type in ["HunyuanDiT", "Kolors"]: + if model_type in ["HunyuanDiT", "Kolors", "FLUX"]: file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))] file_list = sorted(file_list) return file_list @@ -85,6 +92,16 @@ def load_model(model_type, model_path): os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"), os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"), ]) + elif model_type == "FLUX": + model_manager.torch_dtype = torch.bfloat16 + file_list = [ + os.path.join(model_path, "text_encoder/model.safetensors"), + os.path.join(model_path, "text_encoder_2"), + ] + for file_name in os.listdir(model_path): + if file_name.endswith(".safetensors"): + file_list.append(os.path.join(model_path, file_name)) + model_manager.load_models(file_list) else: model_manager.load_model(model_path) pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)