nexus-gen

This commit is contained in:
Artiprocher
2025-04-30 17:09:15 +08:00
parent ef2a7abad4
commit f7737aff98
9 changed files with 3200 additions and 7 deletions

View File

@@ -202,10 +202,10 @@ class FluxImagePipeline(BasePipeline):
return image
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
if self.text_encoder_1 is not None and self.text_encoder_2 is not None:
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512, image_emb=None):
if (self.text_encoder_1 is not None and self.text_encoder_2 is not None) or (image_emb is not None):
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length, image_emb=image_emb
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
else:
@@ -358,13 +358,13 @@ class FluxImagePipeline(BasePipeline):
return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask
def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale, image_emb=None):
# Extend prompt
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
# Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length, image_emb=image_emb)
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
@@ -432,6 +432,7 @@ class FluxImagePipeline(BasePipeline):
height=1024,
width=1024,
seed=None,
image_emb=None,
# Steps
num_inference_steps=30,
# local prompts
@@ -483,7 +484,7 @@ class FluxImagePipeline(BasePipeline):
latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride)
# Prompt
prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale)
prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale, image_emb)
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)

View File

@@ -59,6 +59,7 @@ class FluxPrompter(BasePrompter):
positive=True,
device="cuda",
t5_sequence_length=512,
image_emb=None,
):
prompt = self.process_prompt(prompt, positive=positive)
@@ -66,7 +67,10 @@ class FluxPrompter(BasePrompter):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
# T5
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
if image_emb is not None:
prompt_emb = image_emb
else:
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
# text_ids
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)