support teacache

This commit is contained in:
Artiprocher
2025-01-13 15:56:33 +08:00
parent 913591c13e
commit c0889c2564
3 changed files with 121 additions and 30 deletions

View File

@@ -280,6 +280,8 @@ class FluxImagePipeline(BasePipeline):
eligen_entity_masks=None,
enable_eligen_on_negative=False,
enable_eligen_inpaint=False,
# TeaCache
tea_cache_l1_thresh=None,
# Tile
tiled=False,
tile_size=128,
@@ -314,6 +316,9 @@ class FluxImagePipeline(BasePipeline):
# ControlNets
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
# TeaCache
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
# Denoise
self.load_models_to_device(['dit', 'controlnet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
@@ -323,7 +328,7 @@ class FluxImagePipeline(BasePipeline):
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
)
noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -362,6 +367,48 @@ class FluxImagePipeline(BasePipeline):
return image
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
def check(self, dit: FluxDiT, hidden_states, conditioning):
inp = hidden_states.clone()
temb_ = conditioning.clone()
modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_)
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step += 1
if self.step == self.num_inference_steps:
self.step = 0
if should_calc:
self.previous_hidden_states = hidden_states.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
def lets_dance_flux(
dit: FluxDiT,
@@ -380,6 +427,7 @@ def lets_dance_flux(
entity_prompt_emb=None,
entity_masks=None,
ipadapter_kwargs_list={},
tea_cache: TeaCache = None,
**kwargs
):
if tiled:
@@ -446,36 +494,48 @@ def lets_dance_flux(
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
attention_mask = None
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
# TeaCache
if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
else:
tea_cache_update = False
# Single Blocks
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
num_joint_blocks = len(dit.blocks)
for block_id, block in enumerate(dit.single_blocks):
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
if tea_cache_update:
hidden_states = tea_cache.update(hidden_states)
else:
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
# Single Blocks
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
num_joint_blocks = len(dit.blocks)
for block_id, block in enumerate(dit.single_blocks):
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
if tea_cache is not None:
tea_cache.store(hidden_states)
hidden_states = dit.final_norm_out(hidden_states, conditioning)
hidden_states = dit.final_proj_out(hidden_states)