Support LoRA encoder (#695)

* lora_encoder
This commit is contained in:
Zhongjie Duan
2025-07-19 20:44:03 +08:00
committed by GitHub
parent d19fcc8c04
commit 1384de0353
7 changed files with 294 additions and 2 deletions

View File

@@ -20,6 +20,7 @@ from ..models.flux_controlnet import FluxControlNet
from ..models.flux_ipadapter import FluxIpAdapter
from ..models.flux_value_control import MultiValueEncoder
from ..models.flux_infiniteyou import InfiniteYouImageProjector
from ..models.flux_lora_encoder import FluxLoRAEncoder, LoRALayerBlock
from ..models.tiler import FastTileWorker
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher
@@ -97,6 +98,7 @@ class FluxImagePipeline(BasePipeline):
self.infinityou_processor: InfinitYou = None
self.image_proj_model: InfiniteYouImageProjector = None
self.lora_patcher: FluxLoraPatcher = None
self.lora_encoder: FluxLoRAEncoder = None
self.unit_runner = PipelineUnitRunner()
self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher")
self.units = [
@@ -115,6 +117,7 @@ class FluxImagePipeline(BasePipeline):
FluxImageUnit_Flex(),
FluxImageUnit_Step1x(),
FluxImageUnit_ValueControl(),
FluxImageUnit_LoRAEncode(),
]
self.model_fn = model_fn_flux_image
@@ -196,6 +199,7 @@ class FluxImagePipeline(BasePipeline):
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.GroupNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
LoRALayerBlock: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
@@ -207,6 +211,33 @@ class FluxImagePipeline(BasePipeline):
),
vram_limit=vram_limit,
)
def enable_lora_magic(self):
if self.dit is not None:
if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled):
dtype = next(iter(self.dit.parameters())).dtype
enable_vram_management(
self.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device=self.device,
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=None,
)
if self.lora_patcher is not None:
for name, module in self.dit.named_modules():
if isinstance(module, AutoWrappedLinear):
merger_name = name.replace(".", "___")
if merger_name in self.lora_patcher.model_dict:
module.lora_merger = self.lora_patcher.model_dict[merger_name]
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
@@ -219,7 +250,7 @@ class FluxImagePipeline(BasePipeline):
vram_limit = vram_limit - vram_buffer
# Default config
default_vram_management_models = ["text_encoder_1", "vae_decoder", "vae_encoder", "controlnet", "image_proj_model", "ipadapter", "lora_patcher", "value_controller", "step1x_connector"]
default_vram_management_models = ["text_encoder_1", "vae_decoder", "vae_encoder", "controlnet", "image_proj_model", "ipadapter", "lora_patcher", "value_controller", "step1x_connector", "lora_encoder"]
for model_name in default_vram_management_models:
self._enable_vram_management_with_default_config(getattr(self, model_name), vram_limit)
@@ -366,6 +397,7 @@ class FluxImagePipeline(BasePipeline):
if pipe.image_proj_model is not None:
pipe.infinityou_processor = InfinitYou(device=device)
pipe.lora_patcher = model_manager.fetch_model("flux_lora_patcher")
pipe.lora_encoder = model_manager.fetch_model("flux_lora_encoder")
# ControlNet
controlnets = []
@@ -437,6 +469,9 @@ class FluxImagePipeline(BasePipeline):
value_controller_inputs: list[float] = None,
# Step1x
step1x_reference_image: Image.Image = None,
# LoRA Encoder
lora_encoder_inputs: Union[list[ModelConfig], ModelConfig, str] = None,
lora_encoder_scale: float = 1.0,
# TeaCache
tea_cache_l1_thresh: float = None,
# Tile
@@ -470,6 +505,7 @@ class FluxImagePipeline(BasePipeline):
"flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop,
"value_controller_inputs": value_controller_inputs,
"step1x_reference_image": step1x_reference_image,
"lora_encoder_inputs": lora_encoder_inputs, "lora_encoder_scale": lora_encoder_scale,
"tea_cache_l1_thresh": tea_cache_l1_thresh,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"progress_bar_cmd": progress_bar_cmd,
@@ -884,6 +920,66 @@ class InfinitYou(torch.nn.Module):
return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}
class FluxImageUnit_LoRAEncode(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
onload_model_names=("lora_encoder",)
)
def parse_lora_encoder_inputs(self, lora_encoder_inputs):
if not isinstance(lora_encoder_inputs, list):
lora_encoder_inputs = [lora_encoder_inputs]
lora_configs = []
for lora_encoder_input in lora_encoder_inputs:
if isinstance(lora_encoder_input, str):
lora_encoder_input = ModelConfig(path=lora_encoder_input)
lora_encoder_input.download_if_necessary()
lora_configs.append(lora_encoder_input)
return lora_configs
def load_lora(self, lora_config, dtype, device):
loader = FluxLoRALoader(torch_dtype=dtype, device=device)
lora = load_state_dict(lora_config.path, torch_dtype=dtype, device=device)
lora = loader.convert_state_dict(lora)
return lora
def lora_embedding(self, pipe, lora_encoder_inputs):
lora_emb = []
for lora_config in self.parse_lora_encoder_inputs(lora_encoder_inputs):
lora = self.load_lora(lora_config, pipe.torch_dtype, pipe.device)
lora_emb.append(pipe.lora_encoder(lora))
lora_emb = torch.concat(lora_emb, dim=1)
return lora_emb
def add_to_text_embedding(self, prompt_emb, text_ids, lora_emb):
prompt_emb = torch.concat([prompt_emb, lora_emb], dim=1)
extra_text_ids = torch.zeros((lora_emb.shape[0], lora_emb.shape[1], 3), device=lora_emb.device, dtype=lora_emb.dtype)
text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
return prompt_emb, text_ids
def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
if inputs_shared.get("lora_encoder_inputs", None) is None:
return inputs_shared, inputs_posi, inputs_nega
# Encode
pipe.load_models_to_device(["lora_encoder"])
lora_encoder_inputs = inputs_shared["lora_encoder_inputs"]
lora_emb = self.lora_embedding(pipe, lora_encoder_inputs)
# Scale
lora_encoder_scale = inputs_shared.get("lora_encoder_scale", None)
if lora_encoder_scale is not None:
lora_emb = lora_emb * lora_encoder_scale
# Add to prompt embedding
inputs_posi["prompt_emb"], inputs_posi["text_ids"] = self.add_to_text_embedding(
inputs_posi["prompt_emb"], inputs_posi["text_ids"], lora_emb)
return inputs_shared, inputs_posi, inputs_nega
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh):
self.num_inference_steps = num_inference_steps

View File

@@ -165,6 +165,7 @@ class ModelConfig:
download_resource: str = "ModelScope"
offload_device: Optional[Union[str, torch.device]] = None
offload_dtype: Optional[torch.dtype] = None
skip_download: bool = False
def download_if_necessary(self, local_model_path="./models", skip_download=False, use_usp=False):
if self.path is None:
@@ -190,6 +191,7 @@ class ModelConfig:
is_folder = False
# Download
skip_download = skip_download or self.skip_download
if not skip_download:
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
snapshot_download(