mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
step1x, teacache, flex refactor
This commit is contained in:
@@ -15,6 +15,7 @@ from typing_extensions import Literal
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from ..prompters import FluxPrompter
|
||||
from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEncoder2, FluxDiT, FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.step1x_connector import Qwen2Connector
|
||||
from ..models.tiler import FastTileWorker
|
||||
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora.flux_lora import FluxLoRALoader
|
||||
@@ -36,7 +37,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.in_iteration_models = ("dit", )
|
||||
self.qwenvl = None
|
||||
self.step1x_connector: Qwen2Connector = None
|
||||
self.in_iteration_models = ("dit", "step1x_connector")
|
||||
self.units = [
|
||||
FluxImageUnit_ShapeChecker(),
|
||||
FluxImageUnit_NoiseInitializer(),
|
||||
@@ -46,6 +49,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
FluxImageUnit_EmbeddedGuidanceEmbedder(),
|
||||
FluxImageUnit_IPAdapter(),
|
||||
FluxImageUnit_EntityControl(),
|
||||
FluxImageUnit_TeaCache(),
|
||||
FluxImageUnit_Flex(),
|
||||
FluxImageUnit_Step1x(),
|
||||
]
|
||||
self.model_fn = model_fn_flux_image
|
||||
|
||||
@@ -105,6 +111,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
pipe.prompter.fetch_models(pipe.text_encoder_1, pipe.text_encoder_2)
|
||||
pipe.ipadapter = model_manager.fetch_model("flux_ipadapter")
|
||||
pipe.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
|
||||
# Step1x
|
||||
pipe.qwenvl = model_manager.fetch_model("qwenvl")
|
||||
pipe.step1x_connector = model_manager.fetch_model("step1x_connector")
|
||||
|
||||
return pipe
|
||||
|
||||
@@ -374,6 +383,73 @@ class FluxImageUnit_EntityControl(PipelineUnit):
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class FluxImageUnit_Step1x(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(take_over=True,onload_model_names=("qwenvl","vae_encoder"))
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict):
|
||||
image = inputs_shared.get("step1x_reference_image",None)
|
||||
if image is None:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
else:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
prompt = inputs_posi["prompt"]
|
||||
nega_prompt = inputs_nega["negative_prompt"]
|
||||
captions = [prompt, nega_prompt]
|
||||
ref_images = [image, image]
|
||||
embs, masks = pipe.qwenvl(captions, ref_images)
|
||||
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
image = pipe.vae_encoder(image)
|
||||
inputs_posi.update({"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image})
|
||||
inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image})
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class FluxImageUnit_TeaCache(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("num_inference_steps","tea_cache_l1_thresh"))
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, num_inference_steps, tea_cache_l1_thresh):
|
||||
if tea_cache_l1_thresh is None:
|
||||
return {}
|
||||
else:
|
||||
return {"tea_cache": TeaCache(num_inference_steps=num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh)}
|
||||
|
||||
class FluxImageUnit_Flex(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength, flex_control_stop, tiled, tile_size, tile_stride):
|
||||
if pipe.dit.input_dim == 196:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
if flex_inpaint_image is None:
|
||||
flex_inpaint_image = torch.zeros_like(latents)
|
||||
else:
|
||||
flex_inpaint_image = pipe.preprocess_image(flex_inpaint_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
flex_inpaint_image = pipe.vae_encoder(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if flex_inpaint_mask is None:
|
||||
flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :]
|
||||
else:
|
||||
flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
|
||||
flex_inpaint_mask = pipe.preprocess_image(flex_inpaint_mask).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2
|
||||
flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask)
|
||||
if flex_control_image is None:
|
||||
flex_control_image = torch.zeros_like(latents)
|
||||
else:
|
||||
flex_control_image = pipe.preprocess_image(flex_control_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
flex_control_image = pipe.vae_encoder(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength
|
||||
flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
|
||||
flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1)
|
||||
flex_control_stop_timestep = pipe.scheduler.timesteps[int(flex_control_stop * (len(pipe.scheduler.timesteps) - 1))]
|
||||
return {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep}
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
class TeaCache:
|
||||
def __init__(self, num_inference_steps, rel_l1_thresh):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
Reference in New Issue
Block a user