tmp commit for nexus-gen edit

This commit is contained in:
mi804
2025-07-28 16:18:38 +08:00
parent b8f05bb342
commit 2861ec4d9f
8 changed files with 1721 additions and 6 deletions

View File

@@ -22,6 +22,8 @@ from ..models.flux_value_control import MultiValueEncoder
from ..models.flux_infiniteyou import InfiniteYouImageProjector
from ..models.flux_lora_encoder import FluxLoRAEncoder, LoRALayerBlock
from ..models.tiler import FastTileWorker
from ..models.nexus_gen import NexusGenAutoregressiveModel
from ..models.nexus_gen_projector import NexusGenAdapter, NexusGenImageEmbeddingMerger
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher, FluxLoRAFuser
@@ -94,6 +96,9 @@ class FluxImagePipeline(BasePipeline):
self.ipadapter_image_encoder = None
self.qwenvl = None
self.step1x_connector: Qwen2Connector = None
self.nexus_gen: NexusGenAutoregressiveModel = None
self.nexus_gen_generation_adapter: NexusGenAdapter = None
self.nexus_gen_editing_adapter: NexusGenImageEmbeddingMerger = None
self.value_controller: MultiValueEncoder = None
self.infinityou_processor: InfinitYou = None
self.image_proj_model: InfiniteYouImageProjector = None
@@ -113,6 +118,7 @@ class FluxImagePipeline(BasePipeline):
FluxImageUnit_ControlNet(),
FluxImageUnit_IPAdapter(),
FluxImageUnit_EntityControl(),
FluxImageUnit_NexusGen(),
FluxImageUnit_TeaCache(),
FluxImageUnit_Flex(),
FluxImageUnit_Step1x(),
@@ -397,6 +403,9 @@ class FluxImagePipeline(BasePipeline):
pipe.infinityou_processor = InfinitYou(device=device)
pipe.lora_patcher = model_manager.fetch_model("flux_lora_patcher")
pipe.lora_encoder = model_manager.fetch_model("flux_lora_encoder")
pipe.nexus_gen = model_manager.fetch_model("nexus_gen_llm")
pipe.nexus_gen_generation_adapter = model_manager.fetch_model("nexus_gen_generation_adapter")
pipe.nexus_gen_editing_adapter = model_manager.fetch_model("nexus_gen_editing_adapter")
# ControlNet
controlnets = []
@@ -468,6 +477,8 @@ class FluxImagePipeline(BasePipeline):
value_controller_inputs: Union[list[float], float] = None,
# Step1x
step1x_reference_image: Image.Image = None,
# NexusGen
nexus_gen_reference_image: Image.Image = None,
# LoRA Encoder
lora_encoder_inputs: Union[list[ModelConfig], ModelConfig, str] = None,
lora_encoder_scale: float = 1.0,
@@ -504,6 +515,7 @@ class FluxImagePipeline(BasePipeline):
"flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop,
"value_controller_inputs": value_controller_inputs,
"step1x_reference_image": step1x_reference_image,
"nexus_gen_reference_image": nexus_gen_reference_image,
"lora_encoder_inputs": lora_encoder_inputs, "lora_encoder_scale": lora_encoder_scale,
"tea_cache_l1_thresh": tea_cache_l1_thresh,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
@@ -764,6 +776,60 @@ class FluxImageUnit_EntityControl(PipelineUnit):
return inputs_shared, inputs_posi, inputs_nega
class FluxImageUnit_NexusGen(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
onload_model_names=("nexus_gen", "nexus_gen_generation_adapter", "nexus_gen_editing_adapter"),
)
def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
if pipe.nexus_gen is None:
return inputs_shared, inputs_posi, inputs_nega
pipe.load_models_to_device(self.onload_model_names)
if inputs_shared.get("nexus_gen_reference_image", None) is None:
assert pipe.nexus_gen_generation_adapter is not None, "NexusGen requires a generation adapter to be set."
embed = pipe.nexus_gen(inputs_posi["prompt"])[0].unsqueeze(0)
inputs_posi["prompt_emb"] = pipe.nexus_gen_generation_adapter(embed)
inputs_posi['text_ids'] = torch.zeros(embed.shape[0], embed.shape[1], 3).to(device=pipe.device, dtype=pipe.torch_dtype)
else:
assert pipe.nexus_gen_editing_adapter is not None, "NexusGen requires an editing adapter to be set."
embed, ref_embed, grids = pipe.nexus_gen(inputs_posi["prompt"], inputs_shared["nexus_gen_reference_image"])
embeds_grid = grids[1:2].to(device=pipe.device, dtype=torch.long)
ref_embeds_grid = grids[0:1].to(device=pipe.device, dtype=torch.long)
inputs_posi["prompt_emb"] = pipe.nexus_gen_editing_adapter(embed.unsqueeze(0), embeds_grid, ref_embed.unsqueeze(0), ref_embeds_grid)
inputs_posi["text_ids"] = self.get_editing_text_ids(
inputs_shared["latents"],
embeds_grid[0][1].item(), embeds_grid[0][2].item(),
ref_embeds_grid[0][1].item(), ref_embeds_grid[0][2].item(),
)
return inputs_shared, inputs_posi, inputs_nega
def get_editing_text_ids(self, latents, target_embed_height, target_embed_width, ref_embed_height, ref_embed_width):
# prepare text ids for target and reference embeddings
batch_size, height, width = latents.shape[0], target_embed_height, target_embed_width
embed_ids = torch.zeros(height // 2, width // 2, 3)
scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width
embed_ids[..., 1] = embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height
embed_ids[..., 2] = embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width
embed_ids = embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3)
embed_text_ids = embed_ids.to(device=latents.device, dtype=latents.dtype)
batch_size, height, width = latents.shape[0], ref_embed_height, ref_embed_width
ref_embed_ids = torch.zeros(height // 2, width // 2, 3)
scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width
ref_embed_ids[..., 0] = ref_embed_ids[..., 0] + 1.0
ref_embed_ids[..., 1] = ref_embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height
ref_embed_ids[..., 2] = ref_embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width
ref_embed_ids = ref_embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3)
ref_embed_text_ids = ref_embed_ids.to(device=latents.device, dtype=latents.dtype)
text_ids = torch.cat([embed_text_ids, ref_embed_text_ids], dim=1)
return text_ids
class FluxImageUnit_Step1x(PipelineUnit):
def __init__(self):
super().__init__(take_over=True,onload_model_names=("qwenvl","vae_encoder"))