mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 00:38:11 +00:00
flux-refactor
This commit is contained in:
@@ -18,10 +18,13 @@ from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEnc
|
||||
from ..models.step1x_connector import Qwen2Connector
|
||||
from ..models.flux_controlnet import FluxControlNet
|
||||
from ..models.flux_ipadapter import FluxIpAdapter
|
||||
from ..models.flux_infiniteyou import InfiniteYouImageProjector
|
||||
from ..models.tiler import FastTileWorker
|
||||
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora.flux_lora import FluxLoRALoader
|
||||
|
||||
from ..vram_management import gradient_checkpoint_forward
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -89,6 +92,8 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.qwenvl = None
|
||||
self.step1x_connector: Qwen2Connector = None
|
||||
self.infinityou_processor: InfinitYou = None
|
||||
self.image_proj_model: InfiniteYouImageProjector = None
|
||||
self.in_iteration_models = ("dit", "step1x_connector", "controlnet")
|
||||
self.units = [
|
||||
FluxImageUnit_ShapeChecker(),
|
||||
@@ -209,7 +214,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
# ControlNet
|
||||
controlnet_inputs: list[ControlNetInput] = None,
|
||||
# IP-Adapter
|
||||
ipadapter_images: list[Image.Image] = None,
|
||||
ipadapter_images: Union[list[Image.Image], Image.Image] = None,
|
||||
ipadapter_scale: float = 1.0,
|
||||
# EliGen
|
||||
eligen_entity_prompts: list[str] = None,
|
||||
@@ -426,6 +431,8 @@ class FluxImageUnit_IPAdapter(PipelineUnit):
|
||||
ipadapter_images, ipadapter_scale = inputs_shared.get("ipadapter_images", None), inputs_shared.get("ipadapter_scale", 1.0)
|
||||
if ipadapter_images is None:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
if not isinstance(ipadapter_images, list):
|
||||
ipadapter_images = [ipadapter_images]
|
||||
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
images = [image.convert("RGB").resize((384, 384), resample=3) for image in ipadapter_images]
|
||||
@@ -700,6 +707,8 @@ def model_fn_flux_image(
|
||||
tea_cache: TeaCache = None,
|
||||
progress_id=0,
|
||||
num_inference_steps=1,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs
|
||||
):
|
||||
if tiled:
|
||||
@@ -805,13 +814,16 @@ def model_fn_flux_image(
|
||||
else:
|
||||
# Joint Blocks
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states, prompt_emb = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None),
|
||||
)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is not None:
|
||||
@@ -821,13 +833,16 @@ def model_fn_flux_image(
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
num_joint_blocks = len(dit.blocks)
|
||||
for block_id, block in enumerate(dit.single_blocks):
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states, prompt_emb = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
|
||||
)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None:
|
||||
|
||||
@@ -178,7 +178,7 @@ class ModelConfig:
|
||||
skip_download = dist.get_rank() != 0
|
||||
|
||||
# Check whether the origin path is a folder
|
||||
if self.origin_file_pattern is None:
|
||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||
self.origin_file_pattern = ""
|
||||
allow_file_pattern = None
|
||||
is_folder = True
|
||||
|
||||
Reference in New Issue
Block a user