mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 16:18:13 +00:00
diffusion skills framework
This commit is contained in:
@@ -0,0 +1,56 @@
|
||||
from diffsynth.diffusion.skills import SkillsPipeline
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
skills = SkillsPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="DiffSynth-Studio/F2KB4B-Skills-ControlNet"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/F2KB4B-Skills-Brightness"),
|
||||
],
|
||||
)
|
||||
skill_cache = skills(
|
||||
positive_inputs = [
|
||||
{
|
||||
"model_id": 0,
|
||||
"image": Image.open("xxx.jpg"),
|
||||
"prompt": "一位长发少女,四周环绕着魔法粒子",
|
||||
},
|
||||
{
|
||||
"model_id": 1,
|
||||
"scale": 0.6,
|
||||
},
|
||||
],
|
||||
negative_inputs = [
|
||||
{
|
||||
"model_id": 0,
|
||||
"image": Image.open("xxx.jpg"),
|
||||
"prompt": "一位长发少女,四周环绕着魔法粒子",
|
||||
},
|
||||
{
|
||||
"model_id": 1,
|
||||
"scale": 0.5,
|
||||
},
|
||||
],
|
||||
pipe=pipe,
|
||||
)
|
||||
image = pipe(
|
||||
prompt="一位长发少女,四周环绕着魔法粒子",
|
||||
seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4,
|
||||
height=1024, width=1024,
|
||||
**skill_cache,
|
||||
)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,16 @@
|
||||
accelerate launch examples/flux2/model_training/train.py \
|
||||
--dataset_base_path /mnt/nas1/duanzhongjie.dzj/dataset/ImagePulseV2 \
|
||||
--dataset_metadata_path /mnt/nas1/duanzhongjie.dzj/dataset/ImagePulseV2/metadata_example_ti2ti.jsonl \
|
||||
--extra_inputs "skill_inputs" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 1 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--skill_model_id_or_path "models/base" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 999 \
|
||||
--remove_prefix_in_ckpt "pipe.skill_model." \
|
||||
--output_path "./models/train/FLUX.2-klein-base-4B-skills_full" \
|
||||
--trainable_models "skill_model" \
|
||||
--use_gradient_checkpointing \
|
||||
--save_steps 200
|
||||
@@ -0,0 +1,60 @@
|
||||
from diffsynth import load_state_dict
|
||||
from safetensors.torch import save_file
|
||||
import torch
|
||||
|
||||
|
||||
def Flux2DiTStateDictConverter(state_dict):
|
||||
rename_dict = {
|
||||
"time_guidance_embed.timestep_embedder.linear_1.weight": "time_guidance_embed.timestep_embedder.0.weight",
|
||||
"time_guidance_embed.timestep_embedder.linear_2.weight": "time_guidance_embed.timestep_embedder.2.weight",
|
||||
"x_embedder.weight": "img_embedder.weight",
|
||||
"context_embedder.weight": "txt_embedder.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = state_dict[name]
|
||||
elif name.startswith("transformer_blocks"):
|
||||
if name.endswith("attn.to_q.weight"):
|
||||
state_dict_[name.replace("to_q", "img_to_qkv").replace(".attn.", ".")] = torch.concat([
|
||||
state_dict[name.replace("to_q", "to_q")],
|
||||
state_dict[name.replace("to_q", "to_k")],
|
||||
state_dict[name.replace("to_q", "to_v")],
|
||||
], dim=0)
|
||||
elif name.endswith("attn.to_k.weight") or name.endswith("attn.to_v.weight"):
|
||||
continue
|
||||
elif name.endswith("attn.to_out.0.weight"):
|
||||
state_dict_[name.replace("attn.to_out.0.weight", "img_to_out.weight")] = state_dict[name]
|
||||
elif name.endswith("attn.norm_q.weight"):
|
||||
state_dict_[name.replace("attn.norm_q.weight", "img_norm_q.weight")] = state_dict[name]
|
||||
elif name.endswith("attn.norm_k.weight"):
|
||||
state_dict_[name.replace("attn.norm_k.weight", "img_norm_k.weight")] = state_dict[name]
|
||||
elif name.endswith("attn.norm_added_q.weight"):
|
||||
state_dict_[name.replace("attn.norm_added_q.weight", "txt_norm_q.weight")] = state_dict[name]
|
||||
elif name.endswith("attn.norm_added_k.weight"):
|
||||
state_dict_[name.replace("attn.norm_added_k.weight", "txt_norm_k.weight")] = state_dict[name]
|
||||
elif name.endswith("attn.to_add_out.weight"):
|
||||
state_dict_[name.replace("attn.to_add_out.weight", "txt_to_out.weight")] = state_dict[name]
|
||||
elif name.endswith("attn.add_q_proj.weight"):
|
||||
state_dict_[name.replace("add_q_proj", "txt_to_qkv").replace(".attn.", ".")] = torch.concat([
|
||||
state_dict[name.replace("add_q_proj", "add_q_proj")],
|
||||
state_dict[name.replace("add_q_proj", "add_k_proj")],
|
||||
state_dict[name.replace("add_q_proj", "add_v_proj")],
|
||||
], dim=0)
|
||||
elif ".ff." in name:
|
||||
state_dict_[name.replace(".ff.", ".img_ff.")] = state_dict[name]
|
||||
elif ".ff_context." in name:
|
||||
state_dict_[name.replace(".ff_context.", ".txt_ff.")] = state_dict[name]
|
||||
elif name.endswith("attn.add_k_proj.weight") or name.endswith("attn.add_v_proj.weight"):
|
||||
continue
|
||||
else:
|
||||
state_dict_[name] = state_dict[name]
|
||||
elif name.startswith("single_transformer_blocks"):
|
||||
state_dict_[name.replace(".attn.", ".")] = state_dict[name]
|
||||
else:
|
||||
state_dict_[name] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
|
||||
state_dict = load_state_dict("xxx.safetensors")
|
||||
save_file(state_dict, "yyy.safetensors")
|
||||
@@ -18,6 +18,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule):
|
||||
extra_inputs=None,
|
||||
fp8_models=None,
|
||||
offload_models=None,
|
||||
skill_model_id_or_path=None,
|
||||
device="cpu",
|
||||
task="sft",
|
||||
):
|
||||
@@ -26,6 +27,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule):
|
||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"))
|
||||
self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
|
||||
self.pipe = self.load_training_skill_model(self.pipe, skill_model_id_or_path)
|
||||
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||
|
||||
# Training mode
|
||||
@@ -126,6 +128,7 @@ if __name__ == "__main__":
|
||||
extra_inputs=args.extra_inputs,
|
||||
fp8_models=args.fp8_models,
|
||||
offload_models=args.offload_models,
|
||||
skill_model_id_or_path=args.skill_model_id_or_path,
|
||||
task=args.task,
|
||||
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user