From c0889c2564c4ce8881da0c85bf2c64b1b2680327 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 13 Jan 2025 15:56:33 +0800 Subject: [PATCH] support teacache --- diffsynth/pipelines/flux_image.py | 120 +++++++++++++++++++++-------- examples/TeaCache/README.md | 16 ++++ examples/TeaCache/flux_teacache.py | 15 ++++ 3 files changed, 121 insertions(+), 30 deletions(-) create mode 100644 examples/TeaCache/README.md create mode 100644 examples/TeaCache/flux_teacache.py diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index eb834e7..8cd009f 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -280,6 +280,8 @@ class FluxImagePipeline(BasePipeline): eligen_entity_masks=None, enable_eligen_on_negative=False, enable_eligen_inpaint=False, + # TeaCache + tea_cache_l1_thresh=None, # Tile tiled=False, tile_size=128, @@ -314,6 +316,9 @@ class FluxImagePipeline(BasePipeline): # ControlNets controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative) + # TeaCache + tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None} + # Denoise self.load_models_to_device(['dit', 'controlnet']) for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): @@ -323,7 +328,7 @@ class FluxImagePipeline(BasePipeline): inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( dit=self.dit, controlnet=self.controlnet, hidden_states=latents, timestep=timestep, - **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, + **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, ) noise_pred_posi = self.control_noise_via_local_prompts( prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, @@ -362,6 +367,48 @@ class FluxImagePipeline(BasePipeline): return image +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh): + self.num_inference_steps = num_inference_steps + self.step = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + def check(self, dit: FluxDiT, hidden_states, conditioning): + inp = hidden_states.clone() + temb_ = conditioning.clone() + modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_) + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01] + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = hidden_states.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + def lets_dance_flux( dit: FluxDiT, @@ -380,6 +427,7 @@ def lets_dance_flux( entity_prompt_emb=None, entity_masks=None, ipadapter_kwargs_list={}, + tea_cache: TeaCache = None, **kwargs ): if tiled: @@ -446,36 +494,48 @@ def lets_dance_flux( image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) attention_mask = None - # Joint Blocks - for block_id, block in enumerate(dit.blocks): - hidden_states, prompt_emb = block( - hidden_states, - prompt_emb, - conditioning, - image_rotary_emb, - attention_mask, - ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None) - ) - # ControlNet - if controlnet is not None and controlnet_frames is not None: - hidden_states = hidden_states + controlnet_res_stack[block_id] + # TeaCache + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, hidden_states, conditioning) + else: + tea_cache_update = False - # Single Blocks - hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) - num_joint_blocks = len(dit.blocks) - for block_id, block in enumerate(dit.single_blocks): - hidden_states, prompt_emb = block( - hidden_states, - prompt_emb, - conditioning, - image_rotary_emb, - attention_mask, - ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None) - ) - # ControlNet - if controlnet is not None and controlnet_frames is not None: - hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] - hidden_states = hidden_states[:, prompt_emb.shape[1]:] + if tea_cache_update: + hidden_states = tea_cache.update(hidden_states) + else: + # Joint Blocks + for block_id, block in enumerate(dit.blocks): + hidden_states, prompt_emb = block( + hidden_states, + prompt_emb, + conditioning, + image_rotary_emb, + attention_mask, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None) + ) + # ControlNet + if controlnet is not None and controlnet_frames is not None: + hidden_states = hidden_states + controlnet_res_stack[block_id] + + # Single Blocks + hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) + num_joint_blocks = len(dit.blocks) + for block_id, block in enumerate(dit.single_blocks): + hidden_states, prompt_emb = block( + hidden_states, + prompt_emb, + conditioning, + image_rotary_emb, + attention_mask, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None) + ) + # ControlNet + if controlnet is not None and controlnet_frames is not None: + hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] + hidden_states = hidden_states[:, prompt_emb.shape[1]:] + + if tea_cache is not None: + tea_cache.store(hidden_states) hidden_states = dit.final_norm_out(hidden_states, conditioning) hidden_states = dit.final_proj_out(hidden_states) diff --git a/examples/TeaCache/README.md b/examples/TeaCache/README.md new file mode 100644 index 0000000..6f15c30 --- /dev/null +++ b/examples/TeaCache/README.md @@ -0,0 +1,16 @@ +# TeaCache + +TeaCache ([Timestep Embedding Aware Cache](https://github.com/ali-vilab/TeaCache)) is a training-free caching approach that estimates and leverages the fluctuating differences among model outputs across timesteps, thereby accelerating the inference. + +## Examples + +We provide examples on FLUX.1-dev. See [./flux_teacache.py](./flux_teacache.py). + +Steps: 50 + +GPU: A100 + +|TeaCache is disabled|tea_cache_l1_thresh=0.2|tea_cache_l1_thresh=0.4|tea_cache_l1_thresh=0.6|tea_cache_l1_thresh=0.8| +|-|-|-|-|-| +|23s|13s|9s|6s|5s| +|![image_None](https://github.com/user-attachments/assets/2bf5187a-9693-44d3-9ebb-6c33cd15443f)|![image_0 2](https://github.com/user-attachments/assets/5532ba94-c7e2-446e-a9ba-1c68c0f63350)|![image_0 4](https://github.com/user-attachments/assets/4c57c50d-87cd-493b-8603-1da57ec3b70d)|![image_0 6](https://github.com/user-attachments/assets/1d95a3a9-71f9-4b1a-ad5f-a5ea8d52eca7)|![image_0 8](https://github.com/user-attachments/assets/d8cfdd74-8b45-4048-b1b7-ce480aa23fa1) diff --git a/examples/TeaCache/flux_teacache.py b/examples/TeaCache/flux_teacache.py new file mode 100644 index 0000000..b900654 --- /dev/null +++ b/examples/TeaCache/flux_teacache.py @@ -0,0 +1,15 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline + + +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." + +for tea_cache_l1_thresh in [None, 0.2, 0.4, 0.6, 0.8]: + image = pipe( + prompt=prompt, embedded_guidance=3.5, seed=0, + num_inference_steps=50, tea_cache_l1_thresh=tea_cache_l1_thresh + ) + image.save(f"image_{tea_cache_l1_thresh}.png")