Merge branch 'flux-refactor' into flux-refactor

This commit is contained in:
Zhongjie Duan
2025-06-24 15:19:42 +08:00
committed by GitHub
5 changed files with 271 additions and 50 deletions

View File

@@ -16,12 +16,59 @@ from ..schedulers import FlowMatchScheduler
from ..prompters import FluxPrompter
from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEncoder2, FluxDiT, FluxVAEEncoder, FluxVAEDecoder
from ..models.step1x_connector import Qwen2Connector
from ..models.flux_controlnet import FluxControlNet
from ..models.flux_ipadapter import FluxIpAdapter
from ..models.tiler import FastTileWorker
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora.flux_lora import FluxLoRALoader
@dataclass
class ControlNetInput:
    """Per-ControlNet settings supplied by the caller of the pipeline.

    One instance describes a single ControlNet application: which loaded
    ControlNet to use, how strongly to apply it, over which part of the
    denoising schedule it is active, and the image(s) it is conditioned on.
    """

    # Index into ``MultiControlNet.models`` selecting the ControlNet to run.
    controlnet_id: int = 0
    # Multiplier applied to every residual the ControlNet returns.
    scale: float = 1.0
    # Activity window, compared against the denoising progress computed in
    # ``MultiControlNet.forward``. That progress runs from 1.0 (first step)
    # down to 0.0 (last step); the ControlNet is skipped when the progress
    # is above ``start`` or below ``end``.
    start: float = 1.0
    end: float = 0.0
    # Conditioning image; it is VAE-encoded by ``FluxImageUnit_ControlNet``.
    image: Image.Image = None
    # Optional inpainting mask; when set, it is applied to both the
    # conditioning image and the encoded latents.
    inpaint_mask: Image.Image = None
    # Forwarded unchanged to the ControlNet as its ``processor_id`` argument.
    processor_id: str = None
class MultiControlNet(torch.nn.Module):
    """Run several ControlNets and sum their residuals.

    Each entry of ``controlnet_inputs`` selects one of the wrapped models via
    its ``controlnet_id`` and is paired with the conditioning tensor at the
    same position in ``conditionings``.
    """

    def __init__(self, models: "list[FluxControlNet]"):
        """Wrap ``models`` in a ``ModuleList`` so they are registered submodules."""
        super().__init__()
        self.models = torch.nn.ModuleList(models)

    def process_single_controlnet(self, controlnet_input: "ControlNetInput", conditioning: torch.Tensor, **kwargs):
        """Run one ControlNet and scale its residuals.

        Args:
            controlnet_input: Settings for this ControlNet; ``controlnet_id``,
                ``processor_id`` and ``scale`` are read here.
            conditioning: Conditioning tensor passed to the model as
                ``controlnet_conditioning``.
            **kwargs: Forwarded unchanged to the selected model.

        Returns:
            ``(res_stack, single_res_stack)``: two lists of residuals, each
            element multiplied by ``controlnet_input.scale``.
        """
        model = self.models[controlnet_input.controlnet_id]
        res_stack, single_res_stack = model(
            controlnet_conditioning=conditioning,
            processor_id=controlnet_input.processor_id,
            **kwargs
        )
        scale = controlnet_input.scale
        res_stack = [res * scale for res in res_stack]
        single_res_stack = [res * scale for res in single_res_stack]
        return res_stack, single_res_stack

    def forward(self, conditionings: "list[torch.Tensor]", controlnet_inputs: "list[ControlNetInput]", progress_id, num_inference_steps, **kwargs):
        """Accumulate the residuals of every ControlNet active at this step.

        Args:
            conditionings: One conditioning tensor per ControlNet input.
            controlnet_inputs: Settings for each ControlNet application.
            progress_id: Index of the current denoising step (0-based).
            num_inference_steps: Total number of denoising steps.
            **kwargs: Forwarded to every active ControlNet.

        Returns:
            ``(res_stack, single_res_stack)``: element-wise sums over all
            active ControlNets, or ``(None, None)`` when none is active.

        Raises:
            ValueError: If ``conditionings`` and ``controlnet_inputs`` differ
                in length (``zip`` would otherwise drop the extras silently).
        """
        if len(conditionings) != len(controlnet_inputs):
            raise ValueError(
                f"Got {len(conditionings)} conditionings for {len(controlnet_inputs)} controlnet inputs; "
                "the two lists must have the same length."
            )

        # Progress goes from 1.0 at the first step down to 0.0 at the last one.
        # It does not depend on the ControlNet, so compute it once.
        progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)

        res_stack, single_res_stack = None, None
        for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
            # Skip ControlNets whose [end, start] window excludes this step.
            if progress > controlnet_input.start or progress < controlnet_input.end:
                continue
            res_stack_, single_res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs)
            if res_stack is None:
                res_stack = res_stack_
                single_res_stack = single_res_stack_
            else:
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
                single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
        return res_stack, single_res_stack
class FluxImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
@@ -36,10 +83,13 @@ class FluxImagePipeline(BasePipeline):
self.dit: FluxDiT = None
self.vae_decoder: FluxVAEDecoder = None
self.vae_encoder: FluxVAEEncoder = None
self.controlnet: MultiControlNet = None
self.ipadapter: FluxIpAdapter = None
self.ipadapter_image_encoder = None
self.unit_runner = PipelineUnitRunner()
self.qwenvl = None
self.step1x_connector: Qwen2Connector = None
self.in_iteration_models = ("dit", "step1x_connector")
self.in_iteration_models = ("dit", "step1x_connector", "controlnet")
self.units = [
FluxImageUnit_ShapeChecker(),
FluxImageUnit_NoiseInitializer(),
@@ -47,6 +97,7 @@ class FluxImagePipeline(BasePipeline):
FluxImageUnit_InputImageEmbedder(),
FluxImageUnit_ImageIDs(),
FluxImageUnit_EmbeddedGuidanceEmbedder(),
FluxImageUnit_ControlNet(),
FluxImageUnit_IPAdapter(),
FluxImageUnit_EntityControl(),
FluxImageUnit_TeaCache(),
@@ -111,9 +162,15 @@ class FluxImagePipeline(BasePipeline):
pipe.prompter.fetch_models(pipe.text_encoder_1, pipe.text_encoder_2)
pipe.ipadapter = model_manager.fetch_model("flux_ipadapter")
pipe.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
# Step1x
pipe.qwenvl = model_manager.fetch_model("qwenvl")
pipe.step1x_connector = model_manager.fetch_model("step1x_connector")
# ControlNet
controlnets = []
for model_name, model in zip(model_manager.model_name, model_manager.model):
if model_name == "flux_controlnet":
controlnets.append(model)
pipe.controlnet = MultiControlNet(controlnets)
return pipe
@@ -122,57 +179,57 @@ class FluxImagePipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt,
negative_prompt="",
cfg_scale=1.0,
embedded_guidance=3.5,
t5_sequence_length=512,
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
embedded_guidance: float = 3.5,
t5_sequence_length: int = 512,
# Image
input_image=None,
denoising_strength=1.0,
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height=1024,
width=1024,
height: int = 1024,
width: int = 1024,
# Randomness
seed=None,
rand_device: Optional[str] = "cpu",
seed: int = None,
rand_device: str = "cpu",
# Scheduler
sigma_shift=None,
sigma_shift: float = None,
# Steps
num_inference_steps=30,
num_inference_steps: int = 30,
# local prompts
multidiffusion_prompts=(),
multidiffusion_masks=(),
multidiffusion_scales=(),
# ControlNet
controlnet_inputs=None,
controlnet_inputs: list[ControlNetInput] = None,
# IP-Adapter
ipadapter_images=None,
ipadapter_scale=1.0,
ipadapter_images: list[Image.Image] = None,
ipadapter_scale: float = 1.0,
# EliGen
eligen_entity_prompts=None,
eligen_entity_masks=None,
eligen_enable_on_negative=False,
eligen_enable_inpaint=False,
eligen_entity_prompts: list[str] = None,
eligen_entity_masks: list[Image.Image] = None,
eligen_enable_on_negative: bool = False,
eligen_enable_inpaint: bool = False,
# InfiniteYou
infinityou_id_image=None,
infinityou_guidance=1.0,
infinityou_id_image: Image.Image = None,
infinityou_guidance: float = 1.0,
# Flex
flex_inpaint_image=None,
flex_inpaint_mask=None,
flex_control_image=None,
flex_control_strength=0.5,
flex_control_stop=0.5,
flex_inpaint_image: Image.Image = None,
flex_inpaint_mask: Image.Image = None,
flex_control_image: Image.Image = None,
flex_control_strength: float = 0.5,
flex_control_stop: float = 0.5,
# Step1x
step1x_reference_image=None,
step1x_reference_image: Image.Image = None,
# TeaCache
tea_cache_l1_thresh=None,
tea_cache_l1_thresh: float = None,
# Tile
tiled=False,
tile_size=128,
tile_stride=64,
tiled: bool = False,
tile_size: int = 128,
tile_stride: int = 64,
# Progress bar
progress_bar_cmd=tqdm,
progress_bar_cmd = tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
@@ -210,9 +267,9 @@ class FluxImagePipeline(BasePipeline):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)
if cfg_scale != 1.0:
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep)
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
@@ -310,6 +367,49 @@ class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit):
return {"guidance": guidance}
class FluxImageUnit_ControlNet(PipelineUnit):
    """Pipeline unit that turns ControlNet inputs into latent conditionings.

    Every ``ControlNetInput`` image is encoded with the pipeline's VAE
    encoder. When an inpainting mask is attached, it is applied to the image
    before encoding and appended to the latents as an extra channel afterwards.
    """

    def __init__(self):
        super().__init__(
            input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("vae_encoder",),
        )

    def apply_controlnet_mask_on_latents(self, pipe, latents, mask):
        """Append the inverted, resized mask to ``latents`` along dim 1."""
        # NOTE(review): the (+1)/2 rescale presumably maps a [-1, 1] tensor to
        # [0, 1] - confirm against ``pipe.preprocess_image``.
        mask_tensor = (pipe.preprocess_image(mask) + 1) / 2
        # Average over dim 1, keeping it as a size-1 dimension.
        mask_tensor = mask_tensor.mean(dim=1, keepdim=True)
        # Match the last two dimensions of the latents, then invert.
        resized_mask = torch.nn.functional.interpolate(mask_tensor, size=latents.shape[-2:])
        inverted_mask = 1 - resized_mask
        return torch.concat([latents, inverted_mask], dim=1)

    def apply_controlnet_mask_on_image(self, pipe, image, mask):
        """Return a copy of ``image`` with pixels zeroed where the mask is positive."""
        resized_mask = mask.resize(image.size)
        # Mean over dims 0 and 1 of the preprocessed mask, moved to CPU.
        mask_map = pipe.preprocess_image(resized_mask).mean(dim=[0, 1]).cpu()
        pixels = np.array(image)
        pixels[mask_map > 0] = 0
        return Image.fromarray(pixels)

    def process(self, pipe: FluxImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
        """Encode each ControlNet image into a latent conditioning tensor.

        Returns an empty dict when ``controlnet_inputs`` is ``None``; otherwise
        ``{"controlnet_conditionings": [...]}`` with one tensor per input, in
        the same order as ``controlnet_inputs``.
        """
        if controlnet_inputs is None:
            return {}

        pipe.load_models_to_device(['vae_encoder'])
        encoded = []
        for item in controlnet_inputs:
            source = item.image
            if item.inpaint_mask is not None:
                # Blank out the masked region before encoding.
                source = self.apply_controlnet_mask_on_image(pipe, source, item.inpaint_mask)
            pixels = pipe.preprocess_image(source).to(device=pipe.device, dtype=pipe.torch_dtype)
            latents = pipe.vae_encoder(pixels, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            if item.inpaint_mask is not None:
                # Attach the mask itself as an additional latent channel.
                latents = self.apply_controlnet_mask_on_latents(pipe, latents, item.inpaint_mask)
            encoded.append(latents)
        return {"controlnet_conditionings": encoded}
class FluxImageUnit_IPAdapter(PipelineUnit):
def __init__(self):
super().__init__(
@@ -334,6 +434,7 @@ class FluxImageUnit_IPAdapter(PipelineUnit):
return inputs_shared, inputs_posi, inputs_nega
class FluxImageUnit_EntityControl(PipelineUnit):
def __init__(self):
super().__init__(
@@ -504,7 +605,8 @@ def model_fn_flux_image(
guidance=None,
text_ids=None,
image_ids=None,
controlnet_frames=None,
controlnet_inputs=None,
controlnet_conditionings=None,
tiled=False,
tile_size=128,
tile_stride=64,
@@ -520,11 +622,13 @@ def model_fn_flux_image(
step1x_mask=None,
step1x_reference_latents=None,
tea_cache: TeaCache = None,
progress_id=0,
num_inference_steps=1,
**kwargs
):
if tiled:
def flux_forward_fn(hl, hr, wl, wr):
tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
tiled_controlnet_conditionings = [f[:, :, hl: hr, wl: wr] for f in controlnet_conditionings] if controlnet_conditionings is not None else None
return model_fn_flux_image(
dit=dit,
controlnet=controlnet,
@@ -535,7 +639,8 @@ def model_fn_flux_image(
guidance=guidance,
text_ids=text_ids,
image_ids=None,
controlnet_frames=tiled_controlnet_frames,
controlnet_inputs=controlnet_inputs,
controlnet_conditionings=tiled_controlnet_conditionings,
tiled=False,
**kwargs
)
@@ -551,7 +656,7 @@ def model_fn_flux_image(
hidden_states = latents
# ControlNet
if controlnet is not None and controlnet_frames is not None:
if controlnet is not None and controlnet_conditionings is not None:
controlnet_extra_kwargs = {
"hidden_states": hidden_states,
"timestep": timestep,
@@ -560,15 +665,18 @@ def model_fn_flux_image(
"guidance": guidance,
"text_ids": text_ids,
"image_ids": image_ids,
"controlnet_inputs": controlnet_inputs,
"tiled": tiled,
"tile_size": tile_size,
"tile_stride": tile_stride,
"progress_id": progress_id,
"num_inference_steps": num_inference_steps,
}
if id_emb is not None:
controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})
controlnet_res_stack, controlnet_single_res_stack = controlnet(
controlnet_frames, **controlnet_extra_kwargs
controlnet_conditionings, **controlnet_extra_kwargs
)
# Flex
@@ -630,7 +738,7 @@ def model_fn_flux_image(
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
# Single Blocks
@@ -646,7 +754,7 @@ def model_fn_flux_image(
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
hidden_states = hidden_states[:, prompt_emb.shape[1]:]