mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
tmp commit for nexus-gen edit
This commit is contained in:
@@ -22,6 +22,8 @@ from ..models.flux_value_control import MultiValueEncoder
|
||||
from ..models.flux_infiniteyou import InfiniteYouImageProjector
|
||||
from ..models.flux_lora_encoder import FluxLoRAEncoder, LoRALayerBlock
|
||||
from ..models.tiler import FastTileWorker
|
||||
from ..models.nexus_gen import NexusGenAutoregressiveModel
|
||||
from ..models.nexus_gen_projector import NexusGenAdapter, NexusGenImageEmbeddingMerger
|
||||
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher, FluxLoRAFuser
|
||||
|
||||
@@ -94,6 +96,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.ipadapter_image_encoder = None
|
||||
self.qwenvl = None
|
||||
self.step1x_connector: Qwen2Connector = None
|
||||
self.nexus_gen: NexusGenAutoregressiveModel = None
|
||||
self.nexus_gen_generation_adapter: NexusGenAdapter = None
|
||||
self.nexus_gen_editing_adapter: NexusGenImageEmbeddingMerger = None
|
||||
self.value_controller: MultiValueEncoder = None
|
||||
self.infinityou_processor: InfinitYou = None
|
||||
self.image_proj_model: InfiniteYouImageProjector = None
|
||||
@@ -113,6 +118,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
FluxImageUnit_ControlNet(),
|
||||
FluxImageUnit_IPAdapter(),
|
||||
FluxImageUnit_EntityControl(),
|
||||
FluxImageUnit_NexusGen(),
|
||||
FluxImageUnit_TeaCache(),
|
||||
FluxImageUnit_Flex(),
|
||||
FluxImageUnit_Step1x(),
|
||||
@@ -397,6 +403,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
pipe.infinityou_processor = InfinitYou(device=device)
|
||||
pipe.lora_patcher = model_manager.fetch_model("flux_lora_patcher")
|
||||
pipe.lora_encoder = model_manager.fetch_model("flux_lora_encoder")
|
||||
pipe.nexus_gen = model_manager.fetch_model("nexus_gen_llm")
|
||||
pipe.nexus_gen_generation_adapter = model_manager.fetch_model("nexus_gen_generation_adapter")
|
||||
pipe.nexus_gen_editing_adapter = model_manager.fetch_model("nexus_gen_editing_adapter")
|
||||
|
||||
# ControlNet
|
||||
controlnets = []
|
||||
@@ -468,6 +477,8 @@ class FluxImagePipeline(BasePipeline):
|
||||
value_controller_inputs: Union[list[float], float] = None,
|
||||
# Step1x
|
||||
step1x_reference_image: Image.Image = None,
|
||||
# NexusGen
|
||||
nexus_gen_reference_image: Image.Image = None,
|
||||
# LoRA Encoder
|
||||
lora_encoder_inputs: Union[list[ModelConfig], ModelConfig, str] = None,
|
||||
lora_encoder_scale: float = 1.0,
|
||||
@@ -504,6 +515,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
"flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop,
|
||||
"value_controller_inputs": value_controller_inputs,
|
||||
"step1x_reference_image": step1x_reference_image,
|
||||
"nexus_gen_reference_image": nexus_gen_reference_image,
|
||||
"lora_encoder_inputs": lora_encoder_inputs, "lora_encoder_scale": lora_encoder_scale,
|
||||
"tea_cache_l1_thresh": tea_cache_l1_thresh,
|
||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||
@@ -764,6 +776,60 @@ class FluxImageUnit_EntityControl(PipelineUnit):
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class FluxImageUnit_NexusGen(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
onload_model_names=("nexus_gen", "nexus_gen_generation_adapter", "nexus_gen_editing_adapter"),
|
||||
)
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
if pipe.nexus_gen is None:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
if inputs_shared.get("nexus_gen_reference_image", None) is None:
|
||||
assert pipe.nexus_gen_generation_adapter is not None, "NexusGen requires a generation adapter to be set."
|
||||
embed = pipe.nexus_gen(inputs_posi["prompt"])[0].unsqueeze(0)
|
||||
inputs_posi["prompt_emb"] = pipe.nexus_gen_generation_adapter(embed)
|
||||
inputs_posi['text_ids'] = torch.zeros(embed.shape[0], embed.shape[1], 3).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
else:
|
||||
assert pipe.nexus_gen_editing_adapter is not None, "NexusGen requires an editing adapter to be set."
|
||||
embed, ref_embed, grids = pipe.nexus_gen(inputs_posi["prompt"], inputs_shared["nexus_gen_reference_image"])
|
||||
embeds_grid = grids[1:2].to(device=pipe.device, dtype=torch.long)
|
||||
ref_embeds_grid = grids[0:1].to(device=pipe.device, dtype=torch.long)
|
||||
|
||||
inputs_posi["prompt_emb"] = pipe.nexus_gen_editing_adapter(embed.unsqueeze(0), embeds_grid, ref_embed.unsqueeze(0), ref_embeds_grid)
|
||||
inputs_posi["text_ids"] = self.get_editing_text_ids(
|
||||
inputs_shared["latents"],
|
||||
embeds_grid[0][1].item(), embeds_grid[0][2].item(),
|
||||
ref_embeds_grid[0][1].item(), ref_embeds_grid[0][2].item(),
|
||||
)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
def get_editing_text_ids(self, latents, target_embed_height, target_embed_width, ref_embed_height, ref_embed_width):
|
||||
# prepare text ids for target and reference embeddings
|
||||
batch_size, height, width = latents.shape[0], target_embed_height, target_embed_width
|
||||
embed_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width
|
||||
embed_ids[..., 1] = embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height
|
||||
embed_ids[..., 2] = embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width
|
||||
embed_ids = embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3)
|
||||
embed_text_ids = embed_ids.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
batch_size, height, width = latents.shape[0], ref_embed_height, ref_embed_width
|
||||
ref_embed_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width
|
||||
ref_embed_ids[..., 0] = ref_embed_ids[..., 0] + 1.0
|
||||
ref_embed_ids[..., 1] = ref_embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height
|
||||
ref_embed_ids[..., 2] = ref_embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width
|
||||
ref_embed_ids = ref_embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3)
|
||||
ref_embed_text_ids = ref_embed_ids.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
text_ids = torch.cat([embed_text_ids, ref_embed_text_ids], dim=1)
|
||||
return text_ids
|
||||
|
||||
|
||||
class FluxImageUnit_Step1x(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(take_over=True,onload_model_names=("qwenvl","vae_encoder"))
|
||||
|
||||
Reference in New Issue
Block a user