support step1x

This commit is contained in:
Artiprocher
2025-04-27 20:37:43 +08:00
parent 4f01b37a2a
commit 109a0a0d49
5 changed files with 947 additions and 10 deletions

View File

@@ -1,4 +1,5 @@
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
from ..models.step1x_connector import Qwen2Connector
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler
@@ -32,7 +33,9 @@ class FluxImagePipeline(BasePipeline):
self.ipadapter: FluxIpAdapter = None
self.ipadapter_image_encoder: SiglipVisionModel = None
self.infinityou_processor: InfinitYou = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
self.qwenvl = None
self.step1x_connector: Qwen2Connector = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder', 'step1x_connector']
def enable_vram_management(self, num_persistent_param_in_dit=None):
@@ -167,6 +170,10 @@ class FluxImagePipeline(BasePipeline):
self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
if self.image_proj_model is not None:
self.infinityou_processor = InfinitYou(device=self.device)
# Step1x
self.qwenvl = model_manager.fetch_model("qwenvl")
self.step1x_connector = model_manager.fetch_model("step1x_connector")
@staticmethod
@@ -191,10 +198,13 @@ class FluxImagePipeline(BasePipeline):
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
if self.text_encoder_1 is not None and self.text_encoder_2 is not None:
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
else:
return {}
def prepare_extra_input(self, latents=None, guidance=1.0):
@@ -388,6 +398,17 @@ class FluxImagePipeline(BasePipeline):
else:
flex_kwargs = {}
return flex_kwargs
def prepare_step1x_kwargs(self, prompt, negative_prompt, image):
if image is None:
return {}, {}
captions = [prompt, negative_prompt]
ref_images = [image, image]
embs, masks = self.qwenvl(captions, ref_images)
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
image = self.encode_image(image)
return {"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image}, {"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image}
@torch.no_grad()
@@ -432,6 +453,8 @@ class FluxImagePipeline(BasePipeline):
flex_control_image=None,
flex_control_strength=0.5,
flex_control_stop=0.5,
# Step1x
step1x_reference_image=None,
# TeaCache
tea_cache_l1_thresh=None,
# Tile
@@ -472,7 +495,10 @@ class FluxImagePipeline(BasePipeline):
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
# Flex
flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, **tiler_kwargs)
flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength=flex_control_strength, flex_control_stop=flex_control_stop, **tiler_kwargs)
# Step1x
step1x_kwargs_posi, step1x_kwargs_nega = self.prepare_step1x_kwargs(prompt, negative_prompt, image=step1x_reference_image)
# TeaCache
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
@@ -484,9 +510,9 @@ class FluxImagePipeline(BasePipeline):
# Positive side
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
dit=self.dit, controlnet=self.controlnet, step1x_connector=self.step1x_connector,
hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs, **step1x_kwargs_posi,
)
noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -501,9 +527,9 @@ class FluxImagePipeline(BasePipeline):
if cfg_scale != 1.0:
# Negative side
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
dit=self.dit, controlnet=self.controlnet, step1x_connector=self.step1x_connector,
hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs, **step1x_kwargs_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
@@ -623,6 +649,7 @@ class TeaCache:
def lets_dance_flux(
dit: FluxDiT,
controlnet: FluxMultiControlNetManager = None,
step1x_connector: Qwen2Connector = None,
hidden_states=None,
timestep=None,
prompt_emb=None,
@@ -642,6 +669,9 @@ def lets_dance_flux(
flex_condition=None,
flex_uncondition=None,
flex_control_stop_timestep=None,
step1x_llm_embedding=None,
step1x_mask=None,
step1x_reference_latents=None,
tea_cache: TeaCache = None,
**kwargs
):
@@ -699,6 +729,11 @@ def lets_dance_flux(
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
else:
hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1)
# Step1x
if step1x_llm_embedding is not None:
prompt_emb, pooled_prompt_emb = step1x_connector(step1x_llm_embedding, timestep / 1000, step1x_mask)
text_ids = torch.zeros((1, prompt_emb.shape[1], 3), dtype=prompt_emb.dtype, device=prompt_emb.device)
if image_ids is None:
image_ids = dit.prepare_image_ids(hidden_states)
@@ -710,6 +745,14 @@ def lets_dance_flux(
height, width = hidden_states.shape[-2:]
hidden_states = dit.patchify(hidden_states)
# Step1x
if step1x_reference_latents is not None:
step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents)
step1x_reference_latents = dit.patchify(step1x_reference_latents)
image_ids = torch.concat([image_ids, step1x_reference_image_ids], dim=-2)
hidden_states = torch.concat([hidden_states, step1x_reference_latents], dim=1)
hidden_states = dit.x_embedder(hidden_states)
if entity_prompt_emb is not None and entity_masks is not None:
@@ -764,6 +807,11 @@ def lets_dance_flux(
hidden_states = dit.final_norm_out(hidden_states, conditioning)
hidden_states = dit.final_proj_out(hidden_states)
# Step1x
if step1x_reference_latents is not None:
hidden_states = hidden_states[:, :hidden_states.shape[1] // 2]
hidden_states = dit.unpatchify(hidden_states, height, width)
return hidden_states