mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 02:38:10 +00:00
support teacache
This commit is contained in:
@@ -280,6 +280,8 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
eligen_entity_masks=None,
|
eligen_entity_masks=None,
|
||||||
enable_eligen_on_negative=False,
|
enable_eligen_on_negative=False,
|
||||||
enable_eligen_inpaint=False,
|
enable_eligen_inpaint=False,
|
||||||
|
# TeaCache
|
||||||
|
tea_cache_l1_thresh=None,
|
||||||
# Tile
|
# Tile
|
||||||
tiled=False,
|
tiled=False,
|
||||||
tile_size=128,
|
tile_size=128,
|
||||||
@@ -314,6 +316,9 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
# ControlNets
|
# ControlNets
|
||||||
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
||||||
|
|
||||||
|
# TeaCache
|
||||||
|
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
||||||
|
|
||||||
# Denoise
|
# Denoise
|
||||||
self.load_models_to_device(['dit', 'controlnet'])
|
self.load_models_to_device(['dit', 'controlnet'])
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
@@ -323,7 +328,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||||
dit=self.dit, controlnet=self.controlnet,
|
dit=self.dit, controlnet=self.controlnet,
|
||||||
hidden_states=latents, timestep=timestep,
|
hidden_states=latents, timestep=timestep,
|
||||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi,
|
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
|
||||||
)
|
)
|
||||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||||
@@ -362,6 +367,48 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class TeaCache:
    """Timestep Embedding Aware Cache (TeaCache) for FLUX denoising.

    Training-free acceleration: on timesteps where the timestep-modulated
    input has changed little since the last fully computed step, the
    transformer blocks are skipped and the cached residual from that step is
    reused instead. See https://github.com/ali-vilab/TeaCache.
    """

    # Polynomial rescaling the raw relative-L1 distance of the modulated input
    # into an estimate of the output change (coefficients fitted offline for
    # FLUX, highest degree first). Hoisted to a class attribute so np.poly1d
    # is not rebuilt on every denoising step.
    _RESCALE_FUNC = np.poly1d([4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01])

    def __init__(self, num_inference_steps, rel_l1_thresh):
        # Total steps in one sampling run; used to force computation on the
        # first and last step, and to reset the counter between runs.
        self.num_inference_steps = num_inference_steps
        # Current step index within the run.
        self.step = 0
        # Accumulated (rescaled) relative-L1 distance since the last computed step.
        self.accumulated_rel_l1_distance = 0
        # Modulated input seen on the previous step (for the distance estimate).
        self.previous_modulated_input = None
        # Blocks are skipped while the accumulated distance stays below this.
        self.rel_l1_thresh = rel_l1_thresh
        # Residual (block output minus block input) cached at the last computed step.
        self.previous_residual = None
        # Pre-block hidden states, held between check() and store().
        self.previous_hidden_states = None

    def check(self, dit: "FluxDiT", hidden_states, conditioning):
        """Decide whether the transformer blocks can be skipped this step.

        Returns True when the cached residual should be reused (caller then
        uses ``update``), False when the blocks must be run (caller then
        passes the block output to ``store``).
        """
        # Probe the first block's modulation of the input; clones keep the
        # real forward pass untouched.
        modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(hidden_states.clone(), emb=conditioning.clone())
        if self.step == 0 or self.step == self.num_inference_steps - 1:
            # Always compute on the first and last step to bound the error.
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
            self.accumulated_rel_l1_distance += self._RESCALE_FUNC(rel_l1)
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step += 1
        if self.step == self.num_inference_steps:
            # Wrap around so the same cache object works across sampling runs.
            self.step = 0
        if should_calc:
            # Keep the pre-block hidden states so store() can form the residual.
            self.previous_hidden_states = hidden_states.clone()
        return not should_calc

    def store(self, hidden_states):
        """Cache the residual the transformer blocks added this step."""
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Approximate the block output by adding the cached residual."""
        hidden_states = hidden_states + self.previous_residual
        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def lets_dance_flux(
|
def lets_dance_flux(
|
||||||
dit: FluxDiT,
|
dit: FluxDiT,
|
||||||
@@ -380,6 +427,7 @@ def lets_dance_flux(
|
|||||||
entity_prompt_emb=None,
|
entity_prompt_emb=None,
|
||||||
entity_masks=None,
|
entity_masks=None,
|
||||||
ipadapter_kwargs_list={},
|
ipadapter_kwargs_list={},
|
||||||
|
tea_cache: TeaCache = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if tiled:
|
if tiled:
|
||||||
@@ -446,6 +494,15 @@ def lets_dance_flux(
|
|||||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
|
|
||||||
|
# TeaCache
|
||||||
|
if tea_cache is not None:
|
||||||
|
tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
|
||||||
|
else:
|
||||||
|
tea_cache_update = False
|
||||||
|
|
||||||
|
if tea_cache_update:
|
||||||
|
hidden_states = tea_cache.update(hidden_states)
|
||||||
|
else:
|
||||||
# Joint Blocks
|
# Joint Blocks
|
||||||
for block_id, block in enumerate(dit.blocks):
|
for block_id, block in enumerate(dit.blocks):
|
||||||
hidden_states, prompt_emb = block(
|
hidden_states, prompt_emb = block(
|
||||||
@@ -477,6 +534,9 @@ def lets_dance_flux(
|
|||||||
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
|
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
|
||||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||||
|
|
||||||
|
if tea_cache is not None:
|
||||||
|
tea_cache.store(hidden_states)
|
||||||
|
|
||||||
hidden_states = dit.final_norm_out(hidden_states, conditioning)
|
hidden_states = dit.final_norm_out(hidden_states, conditioning)
|
||||||
hidden_states = dit.final_proj_out(hidden_states)
|
hidden_states = dit.final_proj_out(hidden_states)
|
||||||
hidden_states = dit.unpatchify(hidden_states, height, width)
|
hidden_states = dit.unpatchify(hidden_states, height, width)
|
||||||
|
|||||||
16
examples/TeaCache/README.md
Normal file
16
examples/TeaCache/README.md
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# TeaCache
|
||||||
|
|
||||||
|
TeaCache ([Timestep Embedding Aware Cache](https://github.com/ali-vilab/TeaCache)) is a training-free caching approach that estimates and leverages the fluctuating differences among model outputs across timesteps, thereby accelerating inference.
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
We provide examples on FLUX.1-dev. See [./flux_teacache.py](./flux_teacache.py).
|
||||||
|
|
||||||
|
Steps: 50
|
||||||
|
|
||||||
|
GPU: A100
|
||||||
|
|
||||||
|
|TeaCache is disabled|tea_cache_l1_thresh=0.2|tea_cache_l1_thresh=0.4|tea_cache_l1_thresh=0.6|tea_cache_l1_thresh=0.8|
|
||||||
|
|-|-|-|-|-|
|
||||||
|
|23s|13s|9s|6s|5s|
|
||||||
|
|||||
|
||||||
15
examples/TeaCache/flux_teacache.py
Normal file
15
examples/TeaCache/flux_teacache.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# Benchmark TeaCache acceleration on FLUX.1-dev: render the same prompt with
# TeaCache disabled (None) and at several skip thresholds, one image per run.
import torch
from diffsynth import ModelManager, FluxImagePipeline


model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
pipe = FluxImagePipeline.from_model_manager(model_manager)

prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."

# A threshold of None disables TeaCache; larger thresholds skip more steps.
for thresh in (None, 0.2, 0.4, 0.6, 0.8):
    image = pipe(
        prompt=prompt, embedded_guidance=3.5, seed=0,
        num_inference_steps=50, tea_cache_l1_thresh=thresh,
    )
    image.save(f"image_{thresh}.png")
|
||||||
Reference in New Issue
Block a user