mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support FLUX
This commit is contained in:
@@ -5,5 +5,6 @@ from .sdxl_video import SDXLVideoPipeline
|
||||
from .sd3_image import SD3ImagePipeline
|
||||
from .hunyuan_image import HunyuanDiTImagePipeline
|
||||
from .svd_video import SVDVideoPipeline
|
||||
from .flux_image import FluxImagePipeline
|
||||
from .pipeline_runner import SDVideoPipelineRunner
|
||||
KolorsImagePipeline = SDXLImagePipeline
|
||||
|
||||
@@ -22,7 +22,7 @@ class BasePipeline(torch.nn.Module):
|
||||
|
||||
|
||||
def vae_output_to_image(self, vae_output):
|
||||
image = vae_output[0].cpu().permute(1, 2, 0).numpy()
|
||||
image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
145
diffsynth/pipelines/flux_image.py
Normal file
145
diffsynth/pipelines/flux_image.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from ..models import ModelManager, FluxDiT, FluxTextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder
|
||||
from ..prompters import FluxPrompter
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
||||
class FluxImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = FlowMatchScheduler()
|
||||
self.prompter = FluxPrompter()
|
||||
# models
|
||||
self.text_encoder_1: FluxTextEncoder1 = None
|
||||
self.text_encoder_2: FluxTextEncoder2 = None
|
||||
self.dit: FluxDiT = None
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
return self.dit
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1")
|
||||
self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
|
||||
self.dit = model_manager.fetch_model("flux_dit")
|
||||
self.vae_decoder = model_manager.fetch_model("flux_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("flux_vae_encoder")
|
||||
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
|
||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
pipe = FluxImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, prompt_refiner_classes)
|
||||
return pipe
|
||||
|
||||
|
||||
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.vae_output_to_image(image)
|
||||
return image
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True):
|
||||
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
|
||||
prompt, device=self.device, positive=positive
|
||||
)
|
||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None, guidance=0.0):
|
||||
batch_size, _, height, width = latents.shape
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
guidance = torch.Tensor([guidance] * batch_size).to(device=latents.device, dtype=latents.dtype)
|
||||
return {"image_ids": latent_image_ids, "guidance": guidance}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
local_prompts=[],
|
||||
masks=[],
|
||||
mask_scales=[],
|
||||
cfg_scale=0.0,
|
||||
input_image=None,
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=30,
|
||||
tiled=False,
|
||||
tile_size=128,
|
||||
tile_stride=64,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb = self.encode_prompt(prompt, positive=True)
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
|
||||
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents, guidance=cfg_scale)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Inference (FLUX doesn't support classifier-free guidance)
|
||||
inference_callback = lambda prompt_emb: self.dit(
|
||||
latents, timestep=timestep, **prompt_emb, **tiler_kwargs, **extra_input
|
||||
)
|
||||
noise_pred = self.control_noise_via_local_prompts(prompt_emb, prompt_emb_locals, masks, mask_scales, inference_callback)
|
||||
|
||||
# DDIM
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# UI
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
Reference in New Issue
Block a user