mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
flux_eligen_refactor
This commit is contained in:
@@ -14,9 +14,10 @@ from typing_extensions import Literal
|
||||
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from ..prompters import FluxPrompter
|
||||
from ..models import ModelManager, SD3TextEncoder1, FluxTextEncoder2, FluxDiT, FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEncoder2, FluxDiT, FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.tiler import FastTileWorker
|
||||
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
from ..lora.flux_lora import FluxLoRALoader
|
||||
|
||||
|
||||
|
||||
@@ -44,12 +45,15 @@ class FluxImagePipeline(BasePipeline):
|
||||
FluxImageUnit_ImageIDs(),
|
||||
FluxImageUnit_EmbeddedGuidanceEmbedder(),
|
||||
FluxImageUnit_IPAdapter(),
|
||||
FluxImageUnit_EntityControl(),
|
||||
]
|
||||
self.model_fn = model_fn_flux_image
|
||||
|
||||
|
||||
def load_lora(self, module, path, alpha=1):
    """Load a LoRA state dict from ``path`` and apply it to ``module``.

    The state dict is read with ``load_state_dict`` using this pipeline's
    ``torch_dtype`` and ``device``, then handed to a ``FluxLoRALoader``
    built with the same dtype/device.

    Args:
        module: Target model the LoRA weights are loaded into.
        path: Location of the LoRA checkpoint.
        alpha: Forwarded to ``FluxLoRALoader.load`` (presumably the LoRA
            strength -- confirm against the loader).
    """
    loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
    lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
    loader.load(module, lora, alpha=alpha)
|
||||
|
||||
|
||||
def training_loss(self, **inputs):
|
||||
@@ -321,6 +325,55 @@ class FluxImageUnit_IPAdapter(PipelineUnit):
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class FluxImageUnit_EntityControl(PipelineUnit):
    """Pipeline unit that prepares EliGen entity prompts and masks.

    When both ``eligen_entity_prompts`` and ``eligen_entity_masks`` are present
    in the shared inputs, this unit encodes the entity prompts, preprocesses
    the masks, and stores the results under ``entity_prompt_emb`` and
    ``entity_masks`` in the positive (and, when CFG is active, negative)
    input dictionaries.
    """

    def __init__(self):
        super().__init__(
            take_over=True,
            onload_model_names=("text_encoder_1", "text_encoder_2")
        )

    def preprocess_masks(self, pipe, masks, height, width, dim):
        """Convert entity masks into binary tensors on the pipeline device.

        Each mask is resized to ``(width, height)`` with nearest-neighbour
        resampling, converted with ``pipe.preprocess_image``, averaged over
        dim 1 and thresholded at ``> 0``, then repeated ``dim`` times along
        dim 1 and cast to the pipeline's device and dtype.

        Args:
            pipe: Pipeline providing ``preprocess_image``, ``device`` and
                ``torch_dtype``.
            masks: Iterable of mask images (assumed PIL images, since
                ``resize`` is called with ``Image.NEAREST`` -- confirm).
            height: Target height passed to ``resize``.
            width: Target width passed to ``resize``.
            dim: Repeat count along dim 1.

        Returns:
            List with one tensor per input mask.
        """
        out_masks = []
        for mask in masks:
            mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0
            mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype)
            out_masks.append(mask)
        return out_masks

    def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height, t5_sequence_length=512):
        """Encode entity prompts and stack their masks.

        Masks are preprocessed at 1/8 of the requested resolution with a
        single channel, concatenated along dim 0 and given a leading axis.

        Returns:
            Tuple ``(prompt_emb, entity_masks)`` where both tensors carry an
            added leading dimension.
        """
        entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1)
        entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0)  # b, n_mask, c, h, w

        # Only the first value returned by encode_prompt is used here.
        prompt_emb, _, _ = pipe.prompter.encode_prompt(
            entity_prompts, device=pipe.device, t5_sequence_length=t5_sequence_length
        )
        return prompt_emb.unsqueeze(0), entity_masks

    def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_on_negative, cfg_scale):
        """Build the positive and negative EliGen keyword dictionaries.

        The negative side reuses the positive masks and repeats
        ``prompt_emb_nega['prompt_emb']`` once per mask, but only when
        ``enable_eligen_on_negative`` is truthy and ``cfg_scale != 1.0``;
        otherwise both negative entries are ``None``.

        Returns:
            Tuple ``(eligen_kwargs_posi, eligen_kwargs_nega)``, each a dict
            with keys ``entity_prompt_emb`` and ``entity_masks``.
        """
        entity_prompt_emb_posi, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length)
        if enable_eligen_on_negative and cfg_scale != 1.0:
            entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks_posi.shape[1], 1, 1)
            entity_masks_nega = entity_masks_posi
        else:
            entity_prompt_emb_nega, entity_masks_nega = None, None
        eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi}
        eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega}
        return eligen_kwargs_posi, eligen_kwargs_nega

    def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
        """Attach EliGen inputs to the positive/negative input dictionaries.

        Returns the three dictionaries unchanged when either
        ``eligen_entity_prompts`` or ``eligen_entity_masks`` is missing from
        ``inputs_shared``.
        """
        eligen_entity_prompts = inputs_shared.get("eligen_entity_prompts", None)
        eligen_entity_masks = inputs_shared.get("eligen_entity_masks", None)
        if eligen_entity_prompts is None or eligen_entity_masks is None:
            return inputs_shared, inputs_posi, inputs_nega

        pipe.load_models_to_device(self.onload_model_names)

        # Read cfg_scale once with the same default for both uses, so a missing
        # key cannot raise KeyError before the defaulted check below.
        cfg_scale = inputs_shared.get("cfg_scale", 1.0)
        eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(
            pipe, inputs_nega,
            eligen_entity_prompts, eligen_entity_masks,
            inputs_shared["width"], inputs_shared["height"],
            inputs_shared["t5_sequence_length"],
            inputs_shared["eligen_enable_on_negative"],
            cfg_scale,
        )
        inputs_posi.update(eligen_kwargs_posi)
        if cfg_scale != 1.0:
            inputs_nega.update(eligen_kwargs_nega)
        return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class TeaCache:
|
||||
def __init__(self, num_inference_steps, rel_l1_thresh):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
Reference in New Issue
Block a user