Merge branch 'flux-refactor' into flux-refactor

This commit is contained in:
Zhongjie Duan
2025-06-24 15:19:42 +08:00
committed by GitHub
5 changed files with 271 additions and 50 deletions

View File

@@ -16,12 +16,59 @@ from ..schedulers import FlowMatchScheduler
from ..prompters import FluxPrompter
from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEncoder2, FluxDiT, FluxVAEEncoder, FluxVAEDecoder
from ..models.step1x_connector import Qwen2Connector
from ..models.flux_controlnet import FluxControlNet
from ..models.flux_ipadapter import FluxIpAdapter
from ..models.tiler import FastTileWorker
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora.flux_lora import FluxLoRALoader
@dataclass
class ControlNetInput:
controlnet_id: int = 0
scale: float = 1.0
start: float = 1.0
end: float = 0.0
image: Image.Image = None
inpaint_mask: Image.Image = None
processor_id: str = None
class MultiControlNet(torch.nn.Module):
def __init__(self, models: list[FluxControlNet]):
super().__init__()
self.models = torch.nn.ModuleList(models)
def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs):
model = self.models[controlnet_input.controlnet_id]
res_stack, single_res_stack = model(
controlnet_conditioning=conditioning,
processor_id=controlnet_input.processor_id,
**kwargs
)
res_stack = [res * controlnet_input.scale for res in res_stack]
single_res_stack = [res * controlnet_input.scale for res in single_res_stack]
return res_stack, single_res_stack
def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs):
res_stack, single_res_stack = None, None
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
if progress > controlnet_input.start or progress < controlnet_input.end:
continue
res_stack_, single_res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs)
if res_stack is None:
res_stack = res_stack_
single_res_stack = single_res_stack_
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
return res_stack, single_res_stack
class FluxImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
@@ -36,10 +83,13 @@ class FluxImagePipeline(BasePipeline):
self.dit: FluxDiT = None
self.vae_decoder: FluxVAEDecoder = None
self.vae_encoder: FluxVAEEncoder = None
self.controlnet: MultiControlNet = None
self.ipadapter: FluxIpAdapter = None
self.ipadapter_image_encoder = None
self.unit_runner = PipelineUnitRunner()
self.qwenvl = None
self.step1x_connector: Qwen2Connector = None
self.in_iteration_models = ("dit", "step1x_connector")
self.in_iteration_models = ("dit", "step1x_connector", "controlnet")
self.units = [
FluxImageUnit_ShapeChecker(),
FluxImageUnit_NoiseInitializer(),
@@ -47,6 +97,7 @@ class FluxImagePipeline(BasePipeline):
FluxImageUnit_InputImageEmbedder(),
FluxImageUnit_ImageIDs(),
FluxImageUnit_EmbeddedGuidanceEmbedder(),
FluxImageUnit_ControlNet(),
FluxImageUnit_IPAdapter(),
FluxImageUnit_EntityControl(),
FluxImageUnit_TeaCache(),
@@ -111,9 +162,15 @@ class FluxImagePipeline(BasePipeline):
pipe.prompter.fetch_models(pipe.text_encoder_1, pipe.text_encoder_2)
pipe.ipadapter = model_manager.fetch_model("flux_ipadapter")
pipe.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
# Step1x
pipe.qwenvl = model_manager.fetch_model("qwenvl")
pipe.step1x_connector = model_manager.fetch_model("step1x_connector")
# ControlNet
controlnets = []
for model_name, model in zip(model_manager.model_name, model_manager.model):
if model_name == "flux_controlnet":
controlnets.append(model)
pipe.controlnet = MultiControlNet(controlnets)
return pipe
@@ -122,57 +179,57 @@ class FluxImagePipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt,
negative_prompt="",
cfg_scale=1.0,
embedded_guidance=3.5,
t5_sequence_length=512,
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
embedded_guidance: float = 3.5,
t5_sequence_length: int = 512,
# Image
input_image=None,
denoising_strength=1.0,
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height=1024,
width=1024,
height: int = 1024,
width: int = 1024,
# Randomness
seed=None,
rand_device: Optional[str] = "cpu",
seed: int = None,
rand_device: str = "cpu",
# Scheduler
sigma_shift=None,
sigma_shift: float = None,
# Steps
num_inference_steps=30,
num_inference_steps: int = 30,
# local prompts
multidiffusion_prompts=(),
multidiffusion_masks=(),
multidiffusion_scales=(),
# ControlNet
controlnet_inputs=None,
controlnet_inputs: list[ControlNetInput] = None,
# IP-Adapter
ipadapter_images=None,
ipadapter_scale=1.0,
ipadapter_images: list[Image.Image] = None,
ipadapter_scale: float = 1.0,
# EliGen
eligen_entity_prompts=None,
eligen_entity_masks=None,
eligen_enable_on_negative=False,
eligen_enable_inpaint=False,
eligen_entity_prompts: list[str] = None,
eligen_entity_masks: list[Image.Image] = None,
eligen_enable_on_negative: bool = False,
eligen_enable_inpaint: bool = False,
# InfiniteYou
infinityou_id_image=None,
infinityou_guidance=1.0,
infinityou_id_image: Image.Image = None,
infinityou_guidance: float = 1.0,
# Flex
flex_inpaint_image=None,
flex_inpaint_mask=None,
flex_control_image=None,
flex_control_strength=0.5,
flex_control_stop=0.5,
flex_inpaint_image: Image.Image = None,
flex_inpaint_mask: Image.Image = None,
flex_control_image: Image.Image = None,
flex_control_strength: float = 0.5,
flex_control_stop: float = 0.5,
# Step1x
step1x_reference_image=None,
step1x_reference_image: Image.Image = None,
# TeaCache
tea_cache_l1_thresh=None,
tea_cache_l1_thresh: float = None,
# Tile
tiled=False,
tile_size=128,
tile_stride=64,
tiled: bool = False,
tile_size: int = 128,
tile_stride: int = 64,
# Progress bar
progress_bar_cmd=tqdm,
progress_bar_cmd = tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
@@ -210,9 +267,9 @@ class FluxImagePipeline(BasePipeline):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)
if cfg_scale != 1.0:
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep)
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
@@ -310,6 +367,49 @@ class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit):
return {"guidance": guidance}
class FluxImageUnit_ControlNet(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae_encoder",)
)
def apply_controlnet_mask_on_latents(self, pipe, latents, mask):
mask = (pipe.preprocess_image(mask) + 1) / 2
mask = mask.mean(dim=1, keepdim=True)
mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
latents = torch.concat([latents, mask], dim=1)
return latents
def apply_controlnet_mask_on_image(self, pipe, image, mask):
mask = mask.resize(image.size)
mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()
image = np.array(image)
image[mask > 0] = 0
image = Image.fromarray(image)
return image
def process(self, pipe: FluxImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
if controlnet_inputs is None:
return {}
pipe.load_models_to_device(['vae_encoder'])
conditionings = []
for controlnet_input in controlnet_inputs:
image = controlnet_input.image
if controlnet_input.inpaint_mask is not None:
image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
image = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if controlnet_input.inpaint_mask is not None:
image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
conditionings.append(image)
return {"controlnet_conditionings": conditionings}
class FluxImageUnit_IPAdapter(PipelineUnit):
def __init__(self):
super().__init__(
@@ -334,6 +434,7 @@ class FluxImageUnit_IPAdapter(PipelineUnit):
return inputs_shared, inputs_posi, inputs_nega
class FluxImageUnit_EntityControl(PipelineUnit):
def __init__(self):
super().__init__(
@@ -504,7 +605,8 @@ def model_fn_flux_image(
guidance=None,
text_ids=None,
image_ids=None,
controlnet_frames=None,
controlnet_inputs=None,
controlnet_conditionings=None,
tiled=False,
tile_size=128,
tile_stride=64,
@@ -520,11 +622,13 @@ def model_fn_flux_image(
step1x_mask=None,
step1x_reference_latents=None,
tea_cache: TeaCache = None,
progress_id=0,
num_inference_steps=1,
**kwargs
):
if tiled:
def flux_forward_fn(hl, hr, wl, wr):
tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
tiled_controlnet_conditionings = [f[:, :, hl: hr, wl: wr] for f in controlnet_conditionings] if controlnet_conditionings is not None else None
return model_fn_flux_image(
dit=dit,
controlnet=controlnet,
@@ -535,7 +639,8 @@ def model_fn_flux_image(
guidance=guidance,
text_ids=text_ids,
image_ids=None,
controlnet_frames=tiled_controlnet_frames,
controlnet_inputs=controlnet_inputs,
controlnet_conditionings=tiled_controlnet_conditionings,
tiled=False,
**kwargs
)
@@ -551,7 +656,7 @@ def model_fn_flux_image(
hidden_states = latents
# ControlNet
if controlnet is not None and controlnet_frames is not None:
if controlnet is not None and controlnet_conditionings is not None:
controlnet_extra_kwargs = {
"hidden_states": hidden_states,
"timestep": timestep,
@@ -560,15 +665,18 @@ def model_fn_flux_image(
"guidance": guidance,
"text_ids": text_ids,
"image_ids": image_ids,
"controlnet_inputs": controlnet_inputs,
"tiled": tiled,
"tile_size": tile_size,
"tile_stride": tile_stride,
"progress_id": progress_id,
"num_inference_steps": num_inference_steps,
}
if id_emb is not None:
controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})
controlnet_res_stack, controlnet_single_res_stack = controlnet(
controlnet_frames, **controlnet_extra_kwargs
controlnet_conditionings, **controlnet_extra_kwargs
)
# Flex
@@ -630,7 +738,7 @@ def model_fn_flux_image(
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
# Single Blocks
@@ -646,7 +754,7 @@ def model_fn_flux_image(
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
hidden_states = hidden_states[:, prompt_emb.shape[1]:]

View File

@@ -0,0 +1,37 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
import numpy as np
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors"),
],
)
image_1 = pipe(
prompt="a cat sitting on a chair",
height=1024, width=1024,
seed=8, rand_device="cuda",
)
image_1.save("image_1.jpg")
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask[100:350, 350: -300] = 255
mask = Image.fromarray(mask)
mask.save("mask.jpg")
image_2 = pipe(
prompt="a cat sitting on a chair, wearing sunglasses",
controlnet_inputs=[ControlNetInput(image=image_1, inpaint_mask=mask, scale=0.9)],
height=1024, width=1024,
seed=9, rand_device="cuda",
)
image_2.save("image_2.jpg")

View File

@@ -0,0 +1,40 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from diffsynth.controlnets.processors import Annotator
from diffsynth import download_models
download_models(["Annotators:Depth"])
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors"),
],
)
image_1 = pipe(
prompt="a beautiful Asian girl, full body, red dress, summer",
height=1024, width=1024,
seed=6, rand_device="cuda",
)
image_1.save("image_1.jpg")
image_canny = Annotator("canny")(image_1)
image_depth = Annotator("depth")(image_1)
image_2 = pipe(
prompt="a beautiful Asian girl, full body, red dress, winter",
controlnet_inputs=[
ControlNetInput(image=image_canny, scale=0.3, processor_id="canny"),
ControlNetInput(image=image_depth, scale=0.3, processor_id="depth"),
],
height=1024, width=1024,
seed=7, rand_device="cuda",
)
image_2.save("image_2.jpg")

View File

@@ -0,0 +1,33 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors"),
],
)
image_1 = pipe(
prompt="a photo of a cat, highly detailed",
height=768, width=768,
seed=0, rand_device="cuda",
)
image_1.save("image_1.jpg")
image_1 = image_1.resize((2048, 2048))
image_2 = pipe(
prompt="a photo of a cat, highly detailed",
controlnet_inputs=[ControlNetInput(image=image_1, scale=0.7)],
input_image=image_1,
denoising_strength=0.99,
height=2048, width=2048, tiled=True,
seed=1, rand_device="cuda",
)
image_2.save("image_2.jpg")

View File

@@ -1,8 +1,5 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = FluxImagePipeline.from_pretrained(
@@ -16,8 +13,14 @@ pipe = FluxImagePipeline.from_pretrained(
],
)
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
image = pipe(prompt=prompt, seed=0)
image.save("flux.jpg")
image = pipe(
prompt="a girl",
seed=0,
prompt=prompt, negative_prompt=negative_prompt,
seed=0, cfg_scale=2, num_inference_steps=50,
)
image.save("0.jpg")
image.save("flux_cfg.jpg")