support customized lora forward

This commit is contained in:
lzw478614@alibaba-inc.com
2025-03-25 11:32:09 +08:00
parent 3dc28f428f
commit 04260801a2
4 changed files with 389 additions and 106 deletions

View File

@@ -13,7 +13,7 @@ from transformers import SiglipVisionModel
from copy import deepcopy
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
from ..models.flux_dit import RMSNorm
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
from ..vram_management import enable_vram_management, enable_auto_lora, AutoLoRALinear, AutoWrappedModule, AutoWrappedLinear
class FluxImagePipeline(BasePipeline):
@@ -132,6 +132,15 @@ class FluxImagePipeline(BasePipeline):
)
self.enable_cpu_offload()
def enable_auto_lora(self):
enable_auto_lora(
self.dit,
module_map={
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoLoRALinear,
},
name_prefix=''
)
def denoising_model(self):
return self.dit
@@ -391,6 +400,8 @@ class FluxImagePipeline(BasePipeline):
# Progress bar
progress_bar_cmd=tqdm,
progress_bar_st=None,
lora_state_dicts=[],
lora_alpahs=[]
):
height, width = self.check_resize_height_width(height, width)
@@ -430,6 +441,8 @@ class FluxImagePipeline(BasePipeline):
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
lora_state_dicts=lora_state_dicts,
lora_alpahs = lora_alpahs,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
)
noise_pred_posi = self.control_noise_via_local_prompts(
@@ -447,6 +460,8 @@ class FluxImagePipeline(BasePipeline):
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
lora_state_dicts=lora_state_dicts,
lora_alpahs = lora_alpahs,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
@@ -511,7 +526,6 @@ class TeaCache:
hidden_states = hidden_states + self.previous_residual
return hidden_states
def lets_dance_flux(
dit: FluxDiT,
controlnet: FluxMultiControlNetManager = None,
@@ -532,6 +546,7 @@ def lets_dance_flux(
tea_cache: TeaCache = None,
**kwargs
):
if tiled:
def flux_forward_fn(hl, hr, wl, wr):
tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
@@ -613,7 +628,8 @@ def lets_dance_flux(
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None),
**kwargs
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
@@ -629,7 +645,8 @@ def lets_dance_flux(
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
**kwargs
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
@@ -639,8 +656,8 @@ def lets_dance_flux(
if tea_cache is not None:
tea_cache.store(hidden_states)
hidden_states = dit.final_norm_out(hidden_states, conditioning)
hidden_states = dit.final_proj_out(hidden_states)
hidden_states = dit.final_norm_out(hidden_states, conditioning, **kwargs)
hidden_states = dit.final_proj_out(hidden_states, **kwargs)
hidden_states = dit.unpatchify(hidden_states, height, width)
return hidden_states