diffusion skills framework

This commit is contained in:
Artiprocher
2026-03-17 13:34:25 +08:00
parent 7a80f10fa4
commit f88b99cb4f
11 changed files with 422 additions and 138 deletions

View File

@@ -93,6 +93,9 @@ class Flux2ImagePipeline(BasePipeline):
initial_noise: torch.Tensor = None,
# Steps
num_inference_steps: int = 30,
# KV Cache
skill_cache = None,
negative_skill_cache = None,
# Progress bar
progress_bar_cmd = tqdm,
):
@@ -101,9 +104,11 @@ class Flux2ImagePipeline(BasePipeline):
# Parameters
inputs_posi = {
"prompt": prompt,
"skill_cache": skill_cache,
}
inputs_nega = {
"negative_prompt": negative_prompt,
"skill_cache": negative_skill_cache,
}
inputs_shared = {
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
@@ -570,6 +575,7 @@ def model_fn_flux2(
image_ids=None,
edit_latents=None,
edit_image_ids=None,
skill_cache=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
@@ -587,6 +593,7 @@ def model_fn_flux2(
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=image_ids,
kv_cache=skill_cache,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)