diffusion skills framework

This commit is contained in:
Artiprocher
2026-03-17 13:34:25 +08:00
parent 7a80f10fa4
commit f88b99cb4f
11 changed files with 422 additions and 138 deletions

View File

@@ -0,0 +1,56 @@
from diffsynth.diffusion.skills import SkillsPipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from PIL import Image
pipe = Flux2ImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
skills = SkillsPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="DiffSynth-Studio/F2KB4B-Skills-ControlNet"),
ModelConfig(model_id="DiffSynth-Studio/F2KB4B-Skills-Brightness"),
],
)
skill_cache = skills(
positive_inputs = [
{
"model_id": 0,
"image": Image.open("xxx.jpg"),
"prompt": "一位长发少女,四周环绕着魔法粒子",
},
{
"model_id": 1,
"scale": 0.6,
},
],
negative_inputs = [
{
"model_id": 0,
"image": Image.open("xxx.jpg"),
"prompt": "一位长发少女,四周环绕着魔法粒子",
},
{
"model_id": 1,
"scale": 0.5,
},
],
pipe=pipe,
)
image = pipe(
prompt="一位长发少女,四周环绕着魔法粒子",
seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4,
height=1024, width=1024,
**skill_cache,
)
image.save("image.jpg")

View File

@@ -0,0 +1,16 @@
accelerate launch examples/flux2/model_training/train.py \
--dataset_base_path /mnt/nas1/duanzhongjie.dzj/dataset/ImagePulseV2 \
--dataset_metadata_path /mnt/nas1/duanzhongjie.dzj/dataset/ImagePulseV2/metadata_example_ti2ti.jsonl \
--extra_inputs "skill_inputs" \
--max_pixels 1048576 \
--dataset_repeat 1 \
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
--skill_model_id_or_path "models/base" \
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
--learning_rate 1e-4 \
--num_epochs 999 \
--remove_prefix_in_ckpt "pipe.skill_model." \
--output_path "./models/train/FLUX.2-klein-base-4B-skills_full" \
--trainable_models "skill_model" \
--use_gradient_checkpointing \
--save_steps 200

View File

@@ -0,0 +1,60 @@
from diffsynth import load_state_dict
from safetensors.torch import save_file
import torch
def Flux2DiTStateDictConverter(state_dict):
rename_dict = {
"time_guidance_embed.timestep_embedder.linear_1.weight": "time_guidance_embed.timestep_embedder.0.weight",
"time_guidance_embed.timestep_embedder.linear_2.weight": "time_guidance_embed.timestep_embedder.2.weight",
"x_embedder.weight": "img_embedder.weight",
"context_embedder.weight": "txt_embedder.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
state_dict_[rename_dict[name]] = state_dict[name]
elif name.startswith("transformer_blocks"):
if name.endswith("attn.to_q.weight"):
state_dict_[name.replace("to_q", "img_to_qkv").replace(".attn.", ".")] = torch.concat([
state_dict[name.replace("to_q", "to_q")],
state_dict[name.replace("to_q", "to_k")],
state_dict[name.replace("to_q", "to_v")],
], dim=0)
elif name.endswith("attn.to_k.weight") or name.endswith("attn.to_v.weight"):
continue
elif name.endswith("attn.to_out.0.weight"):
state_dict_[name.replace("attn.to_out.0.weight", "img_to_out.weight")] = state_dict[name]
elif name.endswith("attn.norm_q.weight"):
state_dict_[name.replace("attn.norm_q.weight", "img_norm_q.weight")] = state_dict[name]
elif name.endswith("attn.norm_k.weight"):
state_dict_[name.replace("attn.norm_k.weight", "img_norm_k.weight")] = state_dict[name]
elif name.endswith("attn.norm_added_q.weight"):
state_dict_[name.replace("attn.norm_added_q.weight", "txt_norm_q.weight")] = state_dict[name]
elif name.endswith("attn.norm_added_k.weight"):
state_dict_[name.replace("attn.norm_added_k.weight", "txt_norm_k.weight")] = state_dict[name]
elif name.endswith("attn.to_add_out.weight"):
state_dict_[name.replace("attn.to_add_out.weight", "txt_to_out.weight")] = state_dict[name]
elif name.endswith("attn.add_q_proj.weight"):
state_dict_[name.replace("add_q_proj", "txt_to_qkv").replace(".attn.", ".")] = torch.concat([
state_dict[name.replace("add_q_proj", "add_q_proj")],
state_dict[name.replace("add_q_proj", "add_k_proj")],
state_dict[name.replace("add_q_proj", "add_v_proj")],
], dim=0)
elif ".ff." in name:
state_dict_[name.replace(".ff.", ".img_ff.")] = state_dict[name]
elif ".ff_context." in name:
state_dict_[name.replace(".ff_context.", ".txt_ff.")] = state_dict[name]
elif name.endswith("attn.add_k_proj.weight") or name.endswith("attn.add_v_proj.weight"):
continue
else:
state_dict_[name] = state_dict[name]
elif name.startswith("single_transformer_blocks"):
state_dict_[name.replace(".attn.", ".")] = state_dict[name]
else:
state_dict_[name] = state_dict[name]
return state_dict_
state_dict = load_state_dict("xxx.safetensors")
save_file(state_dict, "yyy.safetensors")

View File

@@ -18,6 +18,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule):
extra_inputs=None,
fp8_models=None,
offload_models=None,
skill_model_id_or_path=None,
device="cpu",
task="sft",
):
@@ -26,6 +27,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule):
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"))
self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
self.pipe = self.load_training_skill_model(self.pipe, skill_model_id_or_path)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
# Training mode
@@ -126,6 +128,7 @@ if __name__ == "__main__":
extra_inputs=args.extra_inputs,
fp8_models=args.fp8_models,
offload_models=args.offload_models,
skill_model_id_or_path=args.skill_model_id_or_path,
task=args.task,
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
)