Merge pull request #670 from modelscope/flux-any-training

support flux any training
This commit is contained in:
Zhongjie Duan
2025-07-08 21:45:02 +08:00
committed by GitHub
29 changed files with 575 additions and 13 deletions

View File

@@ -277,7 +277,7 @@ class FluxLoRAConverter:
pass
@staticmethod
def align_to_opensource_format(state_dict, alpha=1.0):
def align_to_opensource_format(state_dict, alpha=None):
prefix_rename_dict = {
"single_blocks": "lora_unet_single_blocks",
"blocks": "lora_unet_double_blocks",
@@ -316,7 +316,8 @@ class FluxLoRAConverter:
rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix]
state_dict_[rename] = param
if rename.endswith("lora_up.weight"):
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
lora_alpha = alpha if alpha is not None else param.shape[-1]
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((lora_alpha,))[0]
return state_dict_
@staticmethod

View File

@@ -704,7 +704,8 @@ class FluxImageUnit_Step1x(PipelineUnit):
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
image = pipe.vae_encoder(image)
inputs_posi.update({"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image})
inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image})
if inputs_shared.get("cfg_scale", 1) != 1:
inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image})
return inputs_shared, inputs_posi, inputs_nega
@@ -727,6 +728,8 @@ class FluxImageUnit_Flex(PipelineUnit):
def process(self, pipe: FluxImagePipeline, latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength, flex_control_stop, tiled, tile_size, tile_stride):
if pipe.dit.input_dim == 196:
if flex_control_stop is None:
flex_control_stop = 1
pipe.load_models_to_device(self.onload_model_names)
if flex_inpaint_image is None:
flex_inpaint_image = torch.zeros_like(latents)
@@ -760,14 +763,15 @@ class FluxImageUnit_InfiniteYou(PipelineUnit):
def process(self, pipe: FluxImagePipeline, infinityou_id_image, infinityou_guidance):
if infinityou_id_image is not None:
return pipe.infinityou_processor.prepare_infinite_you(pipe.image_proj_model, infinityou_id_image, infinityou_guidance)
return pipe.infinityou_processor.prepare_infinite_you(pipe.image_proj_model, infinityou_id_image, infinityou_guidance, pipe.device)
else:
return {}
class InfinitYou:
class InfinitYou(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
super().__init__()
from facexlib.recognition import init_recognition_model
from insightface.app import FaceAnalysis
self.device = device
@@ -791,16 +795,16 @@ class InfinitYou:
face_info = self.app_160.get(id_image_cv2)
return face_info
def extract_arcface_bgr_embedding(self, in_image, landmark):
def extract_arcface_bgr_embedding(self, in_image, landmark, device):
from insightface.utils import face_align
arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
arc_face_image = 2 * arc_face_image - 1
arc_face_image = arc_face_image.contiguous().to(self.device)
arc_face_image = arc_face_image.contiguous().to(device=device, dtype=self.torch_dtype)
face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
return face_emb
def prepare_infinite_you(self, model, id_image, infinityou_guidance):
def prepare_infinite_you(self, model, id_image, infinityou_guidance, device):
import cv2
if id_image is None:
return {'id_emb': None}
@@ -809,9 +813,9 @@ class InfinitYou:
if len(face_info) == 0:
raise ValueError('No face detected in the input ID image')
landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark, device)
id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=self.device, dtype=self.torch_dtype)
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=device, dtype=self.torch_dtype)
return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}

View File

@@ -0,0 +1,12 @@
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 200 \
--model_id_with_origin_paths "ostris/Flex.2-preview:Flex.2-preview.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
--learning_rate 1e-5 \
--num_epochs 1 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLEX.2-preview_full" \
--trainable_models "dit" \
--use_gradient_checkpointing

View File

@@ -0,0 +1,14 @@
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_inpaint.csv \
--data_file_keys "image,controlnet_image,controlnet_inpaint_mask" \
--max_pixels 1048576 \
--dataset_repeat 400 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta:diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 1 \
--remove_prefix_in_ckpt "pipe.controlnet.models.0." \
--output_path "./models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_full" \
--trainable_models "controlnet" \
--extra_inputs "controlnet_image,controlnet_inpaint_mask" \
--use_gradient_checkpointing

View File

@@ -0,0 +1,14 @@
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \
--data_file_keys "image,controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 400 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-Controlnet-Union-alpha:diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 1 \
--remove_prefix_in_ckpt "pipe.controlnet.models.0." \
--output_path "./models/train/FLUX.1-dev-Controlnet-Union-alpha_full" \
--trainable_models "controlnet" \
--extra_inputs "controlnet_image,controlnet_processor_id" \
--use_gradient_checkpointing

View File

@@ -0,0 +1,14 @@
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \
--data_file_keys "image,controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 400 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,jasperai/Flux.1-dev-Controlnet-Upscaler:diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 1 \
--remove_prefix_in_ckpt "pipe.controlnet.models.0." \
--output_path "./models/train/FLUX.1-dev-Controlnet-Upscaler_full" \
--trainable_models "controlnet" \
--extra_inputs "controlnet_image" \
--use_gradient_checkpointing

View File

@@ -0,0 +1,14 @@
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_infiniteyou.csv \
--data_file_keys "image,controlnet_image,infinityou_id_image" \
--max_pixels 1048576 \
--dataset_repeat 400 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/image_proj_model.bin,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors" \
--learning_rate 1e-5 \
--num_epochs 1 \
--remove_prefix_in_ckpt "pipe." \
--output_path "./models/train/FLUX.1-dev-InfiniteYou_full" \
--trainable_models "controlnet,image_proj_model" \
--extra_inputs "controlnet_image,infinityou_id_image,infinityou_guidance" \
--use_gradient_checkpointing

View File

@@ -0,0 +1,14 @@
accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_step1x.csv \
--data_file_keys "image,step1x_reference_image" \
--max_pixels 1048576 \
--dataset_repeat 400 \
--model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \
--learning_rate 1e-5 \
--num_epochs 1 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Step1X-Edit_full" \
--trainable_models "dit" \
--extra_inputs "step1x_reference_image" \
--use_gradient_checkpointing_offload

View File

@@ -0,0 +1,15 @@
accelerate launch examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "ostris/Flex.2-preview:Flex.2-preview.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLEX.2-preview_lora" \
--lora_base_model "dit" \
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
--lora_rank 32 \
--align_to_opensource_format \
--use_gradient_checkpointing

View File

@@ -0,0 +1,17 @@
accelerate launch examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_inpaint.csv \
--data_file_keys "image,controlnet_image,controlnet_inpaint_mask" \
--max_pixels 1048576 \
--dataset_repeat 100 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta:diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_lora" \
--lora_base_model "dit" \
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
--lora_rank 32 \
--extra_inputs "controlnet_image,controlnet_inpaint_mask" \
--align_to_opensource_format \
--use_gradient_checkpointing

View File

@@ -0,0 +1,17 @@
accelerate launch examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \
--data_file_keys "image,controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 100 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-Controlnet-Union-alpha:diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.1-dev-Controlnet-Union-alpha_lora" \
--lora_base_model "dit" \
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
--lora_rank 32 \
--extra_inputs "controlnet_image,controlnet_processor_id" \
--align_to_opensource_format \
--use_gradient_checkpointing

View File

@@ -0,0 +1,17 @@
accelerate launch examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \
--data_file_keys "image,controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 100 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,jasperai/Flux.1-dev-Controlnet-Upscaler:diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.1-dev-Controlnet-Upscaler_lora" \
--lora_base_model "dit" \
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
--lora_rank 32 \
--extra_inputs "controlnet_image" \
--align_to_opensource_format \
--use_gradient_checkpointing

View File

@@ -0,0 +1,17 @@
accelerate launch examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_ipadapter.csv \
--data_file_keys "image,ipadapter_images" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.1-dev-IP-Adapter_lora" \
--lora_base_model "dit" \
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
--lora_rank 32 \
--extra_inputs "ipadapter_images" \
--align_to_opensource_format \
--use_gradient_checkpointing

View File

@@ -0,0 +1,17 @@
accelerate launch examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_infiniteyou.csv \
--data_file_keys "image,controlnet_image,infinityou_id_image" \
--max_pixels 1048576 \
--dataset_repeat 100 \
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/image_proj_model.bin,ByteDance/InfiniteYou:infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/FLUX.1-dev-InfiniteYou_lora" \
--lora_base_model "dit" \
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
--lora_rank 32 \
--extra_inputs "controlnet_image,infinityou_id_image,infinityou_guidance" \
--align_to_opensource_format \
--use_gradient_checkpointing

View File

@@ -0,0 +1,17 @@
accelerate launch examples/flux/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_step1x.csv \
--data_file_keys "image,step1x_reference_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Step1X-Edit_lora" \
--lora_base_model "dit" \
--lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
--lora_rank 32 \
--extra_inputs "step1x_reference_image" \
--align_to_opensource_format \
--use_gradient_checkpointing

View File

@@ -1,5 +1,5 @@
import torch, os, json
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser
from diffsynth.models.lora import FluxLoRAConverter
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -51,7 +51,7 @@ class FluxTrainingModule(DiffusionTrainingModule):
def forward_preprocess(self, data):
# CFG-sensitive parameters
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {}
inputs_nega = {"negative_prompt": ""}
# CFG-unsensitive parameters
inputs_shared = {
@@ -72,8 +72,14 @@ class FluxTrainingModule(DiffusionTrainingModule):
}
# Extra inputs
controlnet_input = {}
for extra_input in self.extra_inputs:
inputs_shared[extra_input] = data[extra_input]
if extra_input.startswith("controlnet_"):
controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
else:
inputs_shared[extra_input] = data[extra_input]
if len(controlnet_input) > 0:
inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
# Pipeline units will automatically process the input parameters.
for unit in self.pipe.units:
@@ -100,6 +106,7 @@ if __name__ == "__main__":
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
)

View File

@@ -0,0 +1,20 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from diffsynth import load_state_dict
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
state_dict = load_state_dict("models/train/FLEX.2-preview_full/epoch-0.safetensors")
pipe.dit.load_state_dict(state_dict)
image = pipe(prompt="dog,white and brown dog, sitting on wall, under pink flowers", seed=0)
image.save("image_FLEX.2-preview_full.jpg")

View File

@@ -0,0 +1,31 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from diffsynth import load_state_dict
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors"),
],
)
state_dict = load_state_dict("models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_full/epoch-0.safetensors")
pipe.controlnet.models[0].load_state_dict(state_dict)
image = pipe(
prompt="a cat sitting on a chair, wearing sunglasses",
controlnet_inputs=[ControlNetInput(
image=Image.open("data/example_image_dataset/inpaint/image_1.jpg"),
inpaint_mask=Image.open("data/example_image_dataset/inpaint/mask.jpg"),
scale=0.9
)],
height=1024, width=1024,
seed=0, rand_device="cuda",
)
image.save("image_FLUX.1-dev-Controlnet-Inpainting-Beta_full.jpg")

View File

@@ -0,0 +1,31 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from diffsynth import load_state_dict
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors"),
],
)
state_dict = load_state_dict("models/train/FLUX.1-dev-Controlnet-Union-alpha_full/epoch-0.safetensors")
pipe.controlnet.models[0].load_state_dict(state_dict)
image = pipe(
prompt="a dog",
controlnet_inputs=[ControlNetInput(
image=Image.open("data/example_image_dataset/canny/image_1.jpg"),
scale=0.9,
processor_id="canny",
)],
height=768, width=768,
seed=0, rand_device="cuda",
)
image.save("image_FLUX.1-dev-Controlnet-Union-alpha_full.jpg")

View File

@@ -0,0 +1,30 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from diffsynth import load_state_dict
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors"),
],
)
state_dict = load_state_dict("models/train/FLUX.1-dev-Controlnet-Upscaler_full/epoch-0.safetensors")
pipe.controlnet.models[0].load_state_dict(state_dict)
image = pipe(
prompt="a dog",
controlnet_inputs=[ControlNetInput(
image=Image.open("data/example_image_dataset/upscale/image_1.jpg"),
scale=0.9
)],
height=768, width=768,
seed=0, rand_device="cuda",
)
image.save("image_FLUX.1-dev-Controlnet-Upscaler_full.jpg")

View File

@@ -0,0 +1,33 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from diffsynth import load_state_dict
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin"),
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors"),
],
)
state_dict = load_state_dict("models/train/FLUX.1-dev-InfiniteYou_full/epoch-0.safetensors")
state_dict_projector = {i.replace("image_proj_model.", ""): state_dict[i] for i in state_dict if i.startswith("image_proj_model.")}
pipe.image_proj_model.load_state_dict(state_dict_projector)
state_dict_controlnet = {i.replace("controlnet.models.0.", ""): state_dict[i] for i in state_dict if i.startswith("controlnet.models.0.")}
pipe.controlnet.models[0].load_state_dict(state_dict_controlnet)
image = pipe(
prompt="a man with a red hat",
controlnet_inputs=[ControlNetInput(
image=Image.open("data/example_image_dataset/infiniteyou/image_1.jpg"),
)],
height=1024, width=1024,
seed=0, rand_device="cuda",
)
image.save("image_FLUX.1-dev-InfiniteYou_full.jpg")

View File

@@ -0,0 +1,25 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from diffsynth import load_state_dict
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct"),
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"),
],
)
state_dict = load_state_dict("models/train/Step1X-Edit_full/epoch-0.safetensors")
pipe.dit.load_state_dict(state_dict)
image = pipe(
prompt="Make the dog turn its head around.",
step1x_reference_image=Image.open("data/example_image_dataset/2.jpg").resize((768, 768)),
height=768, width=768, cfg_scale=6,
seed=0
)
image.save("image_Step1X-Edit_full.jpg")

View File

@@ -0,0 +1,18 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
pipe.load_lora(pipe.dit, "models/train/FLEX.2-preview_lora/epoch-4.safetensors", alpha=1)
image = pipe(prompt="dog,white and brown dog, sitting on wall, under pink flowers", seed=0)
image.save("image_FLEX.2-preview_lora.jpg")

View File

@@ -0,0 +1,29 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors"),
],
)
pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-Controlnet-Inpainting-Beta_lora/epoch-4.safetensors", alpha=1)
image = pipe(
prompt="a cat sitting on a chair, wearing sunglasses",
controlnet_inputs=[ControlNetInput(
image=Image.open("data/example_image_dataset/inpaint/image_1.jpg"),
inpaint_mask=Image.open("data/example_image_dataset/inpaint/mask.jpg"),
scale=0.9
)],
height=1024, width=1024,
seed=0, rand_device="cuda",
)
image.save("image_FLUX.1-dev-Controlnet-Inpainting-Beta_lora.jpg")

View File

@@ -0,0 +1,29 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors"),
],
)
pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-Controlnet-Union-alpha_lora/epoch-4.safetensors", alpha=1)
image = pipe(
prompt="a dog",
controlnet_inputs=[ControlNetInput(
image=Image.open("data/example_image_dataset/canny/image_1.jpg"),
scale=0.9,
processor_id="canny",
)],
height=768, width=768,
seed=0, rand_device="cuda",
)
image.save("image_FLUX.1-dev-Controlnet-Union-alpha_lora.jpg")

View File

@@ -0,0 +1,28 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors"),
],
)
pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-Controlnet-Upscaler_lora/epoch-4.safetensors", alpha=1)
image = pipe(
prompt="a dog",
controlnet_inputs=[ControlNetInput(
image=Image.open("data/example_image_dataset/upscale/image_1.jpg"),
scale=0.9
)],
height=768, width=768,
seed=0, rand_device="cuda",
)
image.save("image_FLUX.1-dev-Controlnet-Upscaler_lora.jpg")

View File

@@ -0,0 +1,26 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"),
ModelConfig(model_id="google/siglip-so400m-patch14-384"),
],
)
pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-IP-Adapter_lora/epoch-4.safetensors", alpha=1)
image = pipe(
prompt="dog,white and brown dog, sitting on wall, under pink flowers",
ipadapter_images=Image.open("data/example_image_dataset/1.jpg"),
height=768, width=768,
seed=0
)
image.save("image_FLUX.1-dev-IP-Adapter_lora.jpg")

View File

@@ -0,0 +1,28 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin"),
ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors"),
],
)
pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-InfiniteYou_lora/epoch-4.safetensors", alpha=1)
image = pipe(
prompt="a man with a red hat",
controlnet_inputs=[ControlNetInput(
image=Image.open("data/example_image_dataset/infiniteyou/image_1.jpg"),
)],
height=1024, width=1024,
seed=0, rand_device="cuda",
)
image.save("image_FLUX.1-dev-InfiniteYou_lora.jpg")

View File

@@ -0,0 +1,23 @@
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from PIL import Image
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct"),
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"),
],
)
pipe.load_lora(pipe.dit, "models/train/Step1X-Edit_lora/epoch-4.safetensors", alpha=1)
image = pipe(
prompt="Make the dog turn its head around.",
step1x_reference_image=Image.open("data/example_image_dataset/2.jpg").resize((768, 768)),
height=768, width=768, cfg_scale=6,
seed=0
)
image.save("image_Step1X-Edit_lora.jpg")