mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
flux controlnet
This commit is contained in:
@@ -15,12 +15,59 @@ from typing_extensions import Literal
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from ..prompters import FluxPrompter
|
||||
from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEncoder2, FluxDiT, FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.flux_controlnet import FluxControlNet
|
||||
from ..models.flux_ipadapter import FluxIpAdapter
|
||||
from ..models.tiler import FastTileWorker
|
||||
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora.flux_lora import FluxLoRALoader
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlNetInput:
|
||||
controlnet_id: int = 0
|
||||
scale: float = 1.0
|
||||
start: float = 1.0
|
||||
end: float = 0.0
|
||||
image: Image.Image = None
|
||||
inpaint_mask: Image.Image = None
|
||||
processor_id: str = None
|
||||
|
||||
|
||||
|
||||
class MultiControlNet(torch.nn.Module):
|
||||
def __init__(self, models: list[FluxControlNet]):
|
||||
super().__init__()
|
||||
self.models = torch.nn.ModuleList(models)
|
||||
|
||||
def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs):
|
||||
model = self.models[controlnet_input.controlnet_id]
|
||||
res_stack, single_res_stack = model(
|
||||
controlnet_conditioning=conditioning,
|
||||
processor_id=controlnet_input.processor_id,
|
||||
**kwargs
|
||||
)
|
||||
res_stack = [res * controlnet_input.scale for res in res_stack]
|
||||
single_res_stack = [res * controlnet_input.scale for res in single_res_stack]
|
||||
return res_stack, single_res_stack
|
||||
|
||||
def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs):
|
||||
res_stack, single_res_stack = None, None
|
||||
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
|
||||
progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
|
||||
if progress > controlnet_input.start or progress < controlnet_input.end:
|
||||
continue
|
||||
res_stack_, single_res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs)
|
||||
if res_stack is None:
|
||||
res_stack = res_stack_
|
||||
single_res_stack = single_res_stack_
|
||||
else:
|
||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
||||
single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
|
||||
return res_stack, single_res_stack
|
||||
|
||||
|
||||
|
||||
class FluxImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
@@ -35,8 +82,11 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.dit: FluxDiT = None
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
self.controlnet: MultiControlNet = None
|
||||
self.ipadapter: FluxIpAdapter = None
|
||||
self.ipadapter_image_encoder = None
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.in_iteration_models = ("dit", )
|
||||
self.in_iteration_models = ("dit", "controlnet")
|
||||
self.units = [
|
||||
FluxImageUnit_ShapeChecker(),
|
||||
FluxImageUnit_NoiseInitializer(),
|
||||
@@ -44,6 +94,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
FluxImageUnit_InputImageEmbedder(),
|
||||
FluxImageUnit_ImageIDs(),
|
||||
FluxImageUnit_EmbeddedGuidanceEmbedder(),
|
||||
FluxImageUnit_ControlNet(),
|
||||
FluxImageUnit_IPAdapter(),
|
||||
FluxImageUnit_EntityControl(),
|
||||
]
|
||||
@@ -105,6 +156,13 @@ class FluxImagePipeline(BasePipeline):
|
||||
pipe.prompter.fetch_models(pipe.text_encoder_1, pipe.text_encoder_2)
|
||||
pipe.ipadapter = model_manager.fetch_model("flux_ipadapter")
|
||||
pipe.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
|
||||
|
||||
# ControlNet
|
||||
controlnets = []
|
||||
for model_name, model in zip(model_manager.model_name, model_manager.model):
|
||||
if model_name == "flux_controlnet":
|
||||
controlnets.append(model)
|
||||
pipe.controlnet = MultiControlNet(controlnets)
|
||||
|
||||
return pipe
|
||||
|
||||
@@ -113,57 +171,57 @@ class FluxImagePipeline(BasePipeline):
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=1.0,
|
||||
embedded_guidance=3.5,
|
||||
t5_sequence_length=512,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 1.0,
|
||||
embedded_guidance: float = 3.5,
|
||||
t5_sequence_length: int = 512,
|
||||
# Image
|
||||
input_image=None,
|
||||
denoising_strength=1.0,
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Shape
|
||||
height=1024,
|
||||
width=1024,
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
# Randomness
|
||||
seed=None,
|
||||
rand_device: Optional[str] = "cpu",
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
# Scheduler
|
||||
sigma_shift=None,
|
||||
sigma_shift: float = None,
|
||||
# Steps
|
||||
num_inference_steps=30,
|
||||
num_inference_steps: int = 30,
|
||||
# local prompts
|
||||
multidiffusion_prompts=(),
|
||||
multidiffusion_masks=(),
|
||||
multidiffusion_scales=(),
|
||||
# ControlNet
|
||||
controlnet_inputs=None,
|
||||
controlnet_inputs: list[ControlNetInput] = None,
|
||||
# IP-Adapter
|
||||
ipadapter_images=None,
|
||||
ipadapter_scale=1.0,
|
||||
ipadapter_images: list[Image.Image] = None,
|
||||
ipadapter_scale: float = 1.0,
|
||||
# EliGen
|
||||
eligen_entity_prompts=None,
|
||||
eligen_entity_masks=None,
|
||||
eligen_enable_on_negative=False,
|
||||
eligen_enable_inpaint=False,
|
||||
eligen_entity_prompts: list[str] = None,
|
||||
eligen_entity_masks: list[Image.Image] = None,
|
||||
eligen_enable_on_negative: bool = False,
|
||||
eligen_enable_inpaint: bool = False,
|
||||
# InfiniteYou
|
||||
infinityou_id_image=None,
|
||||
infinityou_guidance=1.0,
|
||||
infinityou_id_image: Image.Image = None,
|
||||
infinityou_guidance: float = 1.0,
|
||||
# Flex
|
||||
flex_inpaint_image=None,
|
||||
flex_inpaint_mask=None,
|
||||
flex_control_image=None,
|
||||
flex_control_strength=0.5,
|
||||
flex_control_stop=0.5,
|
||||
flex_inpaint_image: Image.Image = None,
|
||||
flex_inpaint_mask: Image.Image = None,
|
||||
flex_control_image: Image.Image = None,
|
||||
flex_control_strength: float = 0.5,
|
||||
flex_control_stop: float = 0.5,
|
||||
# Step1x
|
||||
step1x_reference_image=None,
|
||||
step1x_reference_image: Image.Image = None,
|
||||
# TeaCache
|
||||
tea_cache_l1_thresh=None,
|
||||
tea_cache_l1_thresh: float = None,
|
||||
# Tile
|
||||
tiled=False,
|
||||
tile_size=128,
|
||||
tile_stride=64,
|
||||
tiled: bool = False,
|
||||
tile_size: int = 128,
|
||||
tile_stride: int = 64,
|
||||
# Progress bar
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
||||
@@ -201,9 +259,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
|
||||
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep)
|
||||
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
@@ -301,6 +359,49 @@ class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit):
|
||||
return {"guidance": guidance}
|
||||
|
||||
|
||||
|
||||
class FluxImageUnit_ControlNet(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae_encoder",)
|
||||
)
|
||||
|
||||
def apply_controlnet_mask_on_latents(self, pipe, latents, mask):
|
||||
mask = (pipe.preprocess_image(mask) + 1) / 2
|
||||
mask = mask.mean(dim=1, keepdim=True)
|
||||
mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
|
||||
latents = torch.concat([latents, mask], dim=1)
|
||||
return latents
|
||||
|
||||
def apply_controlnet_mask_on_image(self, pipe, image, mask):
|
||||
mask = mask.resize(image.size)
|
||||
mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()
|
||||
image = np.array(image)
|
||||
image[mask > 0] = 0
|
||||
image = Image.fromarray(image)
|
||||
return image
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
|
||||
if controlnet_inputs is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(['vae_encoder'])
|
||||
conditionings = []
|
||||
for controlnet_input in controlnet_inputs:
|
||||
image = controlnet_input.image
|
||||
if controlnet_input.inpaint_mask is not None:
|
||||
image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)
|
||||
|
||||
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
image = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
if controlnet_input.inpaint_mask is not None:
|
||||
image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
|
||||
conditionings.append(image)
|
||||
return {"controlnet_conditionings": conditionings}
|
||||
|
||||
|
||||
|
||||
class FluxImageUnit_IPAdapter(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -325,6 +426,7 @@ class FluxImageUnit_IPAdapter(PipelineUnit):
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
|
||||
class FluxImageUnit_EntityControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -428,7 +530,8 @@ def model_fn_flux_image(
|
||||
guidance=None,
|
||||
text_ids=None,
|
||||
image_ids=None,
|
||||
controlnet_frames=None,
|
||||
controlnet_inputs=None,
|
||||
controlnet_conditionings=None,
|
||||
tiled=False,
|
||||
tile_size=128,
|
||||
tile_stride=64,
|
||||
@@ -444,11 +547,13 @@ def model_fn_flux_image(
|
||||
step1x_mask=None,
|
||||
step1x_reference_latents=None,
|
||||
tea_cache: TeaCache = None,
|
||||
progress_id=0,
|
||||
num_inference_steps=1,
|
||||
**kwargs
|
||||
):
|
||||
if tiled:
|
||||
def flux_forward_fn(hl, hr, wl, wr):
|
||||
tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
|
||||
tiled_controlnet_conditionings = [f[:, :, hl: hr, wl: wr] for f in controlnet_conditionings] if controlnet_conditionings is not None else None
|
||||
return model_fn_flux_image(
|
||||
dit=dit,
|
||||
controlnet=controlnet,
|
||||
@@ -459,7 +564,8 @@ def model_fn_flux_image(
|
||||
guidance=guidance,
|
||||
text_ids=text_ids,
|
||||
image_ids=None,
|
||||
controlnet_frames=tiled_controlnet_frames,
|
||||
controlnet_inputs=controlnet_inputs,
|
||||
controlnet_conditionings=tiled_controlnet_conditionings,
|
||||
tiled=False,
|
||||
**kwargs
|
||||
)
|
||||
@@ -475,7 +581,7 @@ def model_fn_flux_image(
|
||||
hidden_states = latents
|
||||
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
if controlnet is not None and controlnet_conditionings is not None:
|
||||
controlnet_extra_kwargs = {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
@@ -484,15 +590,18 @@ def model_fn_flux_image(
|
||||
"guidance": guidance,
|
||||
"text_ids": text_ids,
|
||||
"image_ids": image_ids,
|
||||
"controlnet_inputs": controlnet_inputs,
|
||||
"tiled": tiled,
|
||||
"tile_size": tile_size,
|
||||
"tile_stride": tile_stride,
|
||||
"progress_id": progress_id,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
}
|
||||
if id_emb is not None:
|
||||
controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})
|
||||
controlnet_res_stack, controlnet_single_res_stack = controlnet(
|
||||
controlnet_frames, **controlnet_extra_kwargs
|
||||
controlnet_conditionings, **controlnet_extra_kwargs
|
||||
)
|
||||
|
||||
# Flex
|
||||
@@ -554,7 +663,7 @@ def model_fn_flux_image(
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
|
||||
)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is not None:
|
||||
hidden_states = hidden_states + controlnet_res_stack[block_id]
|
||||
|
||||
# Single Blocks
|
||||
@@ -570,7 +679,7 @@ def model_fn_flux_image(
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
|
||||
)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None:
|
||||
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
|
||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user