Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-22 00:38:11 +00:00)
support customized lora forward
@@ -13,7 +13,7 @@ from transformers import SiglipVisionModel
 from copy import deepcopy
 from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
 from ..models.flux_dit import RMSNorm
-from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
+from ..vram_management import enable_vram_management, enable_auto_lora, AutoLoRALinear, AutoWrappedModule, AutoWrappedLinear
 
 
 class FluxImagePipeline(BasePipeline):
@@ -132,6 +132,15 @@ class FluxImagePipeline(BasePipeline):
         )
         self.enable_cpu_offload()
 
+    def enable_auto_lora(self):
+        enable_auto_lora(
+            self.dit,
+            module_map={
+                RMSNorm: AutoWrappedModule,
+                torch.nn.Linear: AutoLoRALinear,
+            },
+            name_prefix=''
+        )
 
     def denoising_model(self):
         return self.dit
@@ -391,6 +400,8 @@ class FluxImagePipeline(BasePipeline):
         # Progress bar
         progress_bar_cmd=tqdm,
         progress_bar_st=None,
+        lora_state_dicts=[],
+        lora_alpahs=[]
     ):
         height, width = self.check_resize_height_width(height, width)
 
@@ -430,6 +441,8 @@ class FluxImagePipeline(BasePipeline):
             inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
                 dit=self.dit, controlnet=self.controlnet,
                 hidden_states=latents, timestep=timestep,
+                lora_state_dicts=lora_state_dicts,
+                lora_alpahs = lora_alpahs,
                 **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
             )
             noise_pred_posi = self.control_noise_via_local_prompts(
@@ -447,6 +460,8 @@ class FluxImagePipeline(BasePipeline):
                 noise_pred_nega = lets_dance_flux(
                     dit=self.dit, controlnet=self.controlnet,
                     hidden_states=latents, timestep=timestep,
+                    lora_state_dicts=lora_state_dicts,
+                    lora_alpahs = lora_alpahs,
                     **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
                 )
                 noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
@@ -511,7 +526,6 @@ class TeaCache:
         hidden_states = hidden_states + self.previous_residual
         return hidden_states
 
-
 def lets_dance_flux(
     dit: FluxDiT,
     controlnet: FluxMultiControlNetManager = None,
@@ -532,6 +546,7 @@ def lets_dance_flux(
     tea_cache: TeaCache = None,
+    **kwargs
 ):
 
     if tiled:
         def flux_forward_fn(hl, hr, wl, wr):
             tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
@@ -613,7 +628,8 @@ def lets_dance_flux(
                 conditioning,
                 image_rotary_emb,
                 attention_mask,
-                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
+                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None),
+                **kwargs
             )
             # ControlNet
             if controlnet is not None and controlnet_frames is not None:
@@ -629,7 +645,8 @@ def lets_dance_flux(
                 conditioning,
                 image_rotary_emb,
                 attention_mask,
-                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
+                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
+                **kwargs
             )
             # ControlNet
             if controlnet is not None and controlnet_frames is not None:
@@ -639,8 +656,8 @@ def lets_dance_flux(
     if tea_cache is not None:
         tea_cache.store(hidden_states)
 
-    hidden_states = dit.final_norm_out(hidden_states, conditioning)
-    hidden_states = dit.final_proj_out(hidden_states)
+    hidden_states = dit.final_norm_out(hidden_states, conditioning, **kwargs)
+    hidden_states = dit.final_proj_out(hidden_states, **kwargs)
     hidden_states = dit.unpatchify(hidden_states, height, width)
 
     return hidden_states
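Taken together, the commit routes LoRA weights through the forward pass instead of merging them into the base weights: enable_auto_lora() wraps the DiT's torch.nn.Linear layers in AutoLoRALinear (and RMSNorm in AutoWrappedModule), and the new lora_state_dicts / lora_alpahs arguments are forwarded from the pipeline call into lets_dance_flux and, via **kwargs, into every block and the final norm/projection layers. A minimal usage sketch follows; the model paths, the LoRA file name, and the ModelManager / from_model_manager loading steps are illustrative assumptions based on the library's usual examples, not part of this diff.

import torch
from safetensors.torch import load_file
from diffsynth import ModelManager, FluxImagePipeline

# Load the base FLUX model (placeholder path, not taken from this commit).
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models(["models/FLUX/FLUX.1-dev"])
pipe = FluxImagePipeline.from_model_manager(model_manager)

# Wrap the DiT's Linear layers in AutoLoRALinear so LoRA weights can be
# applied at forward time rather than merged into the checkpoint.
pipe.enable_auto_lora()

# Pass a LoRA state dict and its scale through the new pipeline arguments;
# they flow into lets_dance_flux and each block via **kwargs.
lora_state_dict = load_file("models/lora/example_lora.safetensors")  # placeholder file
image = pipe(
    prompt="a photo of a cat",
    seed=0,
    lora_state_dicts=[lora_state_dict],
    lora_alpahs=[1.0],
)
image.save("image.jpg")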