Artiprocher 44e2eecdf1 flux-kontext
2025-06-29 15:59:04 +08:00

FLUX

FLUX is a series of image generation models open-sourced by Black-Forest-Labs.

DiffSynth-Studio has introduced a new inference and training framework. If you need to use the old version, please click here.

Installation

Before using these models, please install DiffSynth-Studio from source code:

git clone https://github.com/modelscope/DiffSynth-Studio.git  
cd DiffSynth-Studio
pip install -e .

Quick Start

You can quickly load the FLUX.1-dev model and perform inference by running the following code:

import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig

pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
    ],
)

image = pipe(prompt="a cat", seed=0)
image.save("image.jpg")

Model Overview

Support for the new framework of the FLUX series models is under active development. Stay tuned!

Model ID Additional Parameters Inference Full Training Validation After Full Training LoRA Training Validation After LoRA Training
black-forest-labs/FLUX.1-dev code code code code code
black-forest-labs/FLUX.1-Kontext-dev kontext_images code code code

Model Inference

The following sections will help you understand our features and write inference code.

Loading Models

Models are loaded using from_pretrained:

pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
    ],
)

Here, torch_dtype and device set the computation precision and device, respectively. Each entry in model_configs specifies a model file in one of two ways:

  • Download the model from ModelScope Community and load it. In this case, provide model_id and origin_file_pattern, for example:
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
  • Load the model from a local file path. In this case, provide the path, for example:
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors")

For models that consist of multiple files, use a list as follows:

ModelConfig(path=[
    "models/xxx/diffusion_pytorch_model-00001-of-00003.safetensors",
    "models/xxx/diffusion_pytorch_model-00002-of-00003.safetensors",
    "models/xxx/diffusion_pytorch_model-00003-of-00003.safetensors",
])

The from_pretrained method also provides additional parameters to control model loading behavior:

  • local_model_path: Path for saving downloaded models. The default is "./models".
  • skip_download: Whether to skip downloading models. The default is False. If your network cannot access ModelScope Community, manually download the required files and set this to True.
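
When skip_download=True, nothing is fetched, so every file must already be on disk. The sketch below lists the files that are still missing. It assumes that downloaded files are stored as <local_model_path>/<model_id>/<origin_file_pattern>, which is consistent with the local paths used in this document (models/black-forest-labs/FLUX.1-dev/...) but is an assumption; missing_model_files is a hypothetical helper, not part of DiffSynth-Studio.

```python
import os

# Files used by the FLUX.1-dev examples in this document: (model_id, origin_file_pattern).
FLUX_DEV_FILES = [
    ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors"),
    ("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors"),
    ("black-forest-labs/FLUX.1-dev", "text_encoder_2/"),
    ("black-forest-labs/FLUX.1-dev", "ae.safetensors"),
]


def missing_model_files(files, local_model_path="./models"):
    """Return the expected local paths that do not exist yet.

    Assumes the layout <local_model_path>/<model_id>/<origin_file_pattern>;
    a pattern ending in "/" is treated as a directory.
    """
    missing = []
    for model_id, pattern in files:
        path = os.path.join(local_model_path, model_id, pattern)
        exists = os.path.isdir(path) if pattern.endswith("/") else os.path.isfile(path)
        if not exists:
            missing.append(path)
    return missing


if __name__ == "__main__":
    missing = missing_model_files(FLUX_DEV_FILES)
    if missing:
        print("Download these before setting skip_download=True:")
        for path in missing:
            print("  " + path)
    else:
        print("All expected files found.")
```
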

VRAM Management

DiffSynth-Studio provides fine-grained VRAM management for FLUX models, enabling inference on devices with limited VRAM. Setting offload_device="cpu" in each ModelConfig and then calling enable_vram_management, as in the following code, enables offloading: model weights are kept in system memory and moved to the GPU as needed.

pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

The enable_vram_management function provides the following parameters to control VRAM usage:

  • vram_limit: VRAM usage limit in GB. By default, the remaining free VRAM on the device is used. This is not a hard limit: if the specified value is too small for inference but more VRAM is actually available, the model runs in its minimal-VRAM mode instead. Setting it to 0 achieves the theoretical minimum VRAM usage.
  • vram_buffer: VRAM buffer size in GB. The default is 0.5 GB. A buffer is needed because some large layers consume extra VRAM while they are being moved onto the GPU. Ideally, set it to the amount of VRAM occupied by the largest layer in the model.
  • num_persistent_param_in_dit: Number of DiT parameters kept resident in VRAM (default: no limit). This parameter will be removed in a future version, so please avoid relying on it.
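
Sizing vram_buffer from the largest layer is simple arithmetic: parameter count times bytes per element. The sketch below applies it to a hypothetical table of layer sizes; the layer names and counts are made up for illustration and are not read from the real FLUX checkpoint. It ignores activations and temporary buffers, so treat the result as a lower bound.

```python
# Bytes per element for common dtypes (bfloat16 matches torch_dtype=torch.bfloat16 above).
BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2, "float16": 2}


def largest_layer_gb(layer_param_counts, dtype="bfloat16"):
    """Return (layer_name, size_in_GB) for the layer with the most parameters."""
    name, count = max(layer_param_counts.items(), key=lambda item: item[1])
    size_gb = count * BYTES_PER_ELEMENT[dtype] / 1024**3
    return name, size_gb


# Hypothetical parameter counts, for illustration only.
example_layers = {
    "embedding": 32_000 * 4_096,        # 131,072,000 params
    "attention.to_qkv": 3_072 * 9_216,  # 28,311,552 params
    "mlp.proj": 3_072 * 12_288,         # 37,748,736 params
}

name, size_gb = largest_layer_gb(example_layers)
print(f"largest layer: {name}, about {size_gb:.3f} GB")
```

With these made-up numbers the largest layer takes about 0.244 GB in bfloat16, so the default vram_buffer of 0.5 GB would already cover it.
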

Inference Acceleration

Input Parameters

The pipeline accepts the following input parameters during inference:

  • prompt: Prompt describing what should appear in the image.
  • negative_prompt: Negative prompt describing what should not appear in the image. Default is "".
  • cfg_scale: Classifier-free guidance scale. Default is 1; guidance takes effect only when it is greater than 1.
  • embedded_guidance: Embedded guidance parameter for FLUX-dev. Default is 3.5.
  • t5_sequence_length: Sequence length of T5 text embeddings. Default is 512.
  • input_image: Input image used for image-to-image generation. This works together with denoising_strength.
  • denoising_strength: Denoising strength, ranging from 0 to 1. Default is 1. Values close to 0 produce an image similar to the input image; values close to 1 produce an image that differs significantly from it. Keep it at 1 when no input_image is provided.
  • height: Height of the generated image. Must be a multiple of 16.
  • width: Width of the generated image. Must be a multiple of 16.
  • seed: Random seed. Default is None, meaning completely random.
  • rand_device: Device for generating random Gaussian noise. Default is "cpu". Setting it to "cuda" may lead to different results across GPUs.
  • sigma_shift: Shift parameter from Rectified Flow theory. Default is 3. A larger value spends more of the inference steps at the beginning of denoising, which can improve image quality, but it may also make the generation process inconsistent with the training process.
  • num_inference_steps: Number of inference steps. Default is 30.
  • kontext_images: Input images for the Kontext model.
  • controlnet_inputs: Inputs for the ControlNet model.
  • ipadapter_images: Input images for the IP-Adapter model.
  • ipadapter_scale: Control strength of the IP-Adapter model.
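
To see why a larger sigma_shift spends more steps at the start of denoising, the sketch below applies the time-shift formula commonly used by flow-matching schedulers, sigma' = s * sigma / (1 + (s - 1) * sigma), to a uniform schedule. Whether this pipeline's scheduler uses exactly this formula and schedule is an assumption here; the function names are illustrative.

```python
def shifted_sigmas(num_inference_steps, sigma_shift):
    """Uniform noise levels from 1.0 down to 1/num_inference_steps, then time-shifted.

    Shift formula: sigma' = s * sigma / (1 + (s - 1) * sigma)
    """
    sigmas = [1.0 - i / num_inference_steps for i in range(num_inference_steps)]
    return [sigma_shift * s / (1 + (sigma_shift - 1) * s) for s in sigmas]


def steps_above(sigmas, threshold=0.5):
    """Count steps spent at high noise levels (early in denoising)."""
    return sum(1 for s in sigmas if s > threshold)


for shift in (1.0, 3.0):
    sigmas = shifted_sigmas(num_inference_steps=30, sigma_shift=shift)
    print(f"sigma_shift={shift}: {steps_above(sigmas)} of 30 steps have sigma > 0.5")
```

With 30 steps, 15 steps have sigma > 0.5 without a shift and 23 with sigma_shift=3, so more of the schedule is spent in the high-noise, early phase.
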

Model Training

FLUX series models are trained using a unified script ./model_training/train.py.

Script Parameters

The script supports the following parameters:

  • Dataset
    • --dataset_base_path: Root path to the dataset.
    • --dataset_metadata_path: Path to the metadata file of the dataset.
    • --height: Height of training images. Leave both --height and --width empty to enable dynamic resolution.
    • --width: Width of training images. Leave both --height and --width empty to enable dynamic resolution.
    • --data_file_keys: Keys in metadata for data files. Comma-separated.
    • --dataset_repeat: Number of times the dataset repeats per epoch.
  • Models
    • --model_paths: Paths to load models. JSON format.
    • --model_id_with_origin_paths: Model IDs with origin file paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Comma-separated.
  • Training
    • --learning_rate: Learning rate.
    • --num_epochs: Number of training epochs.
    • --output_path: Output path for saving checkpoints.
    • --remove_prefix_in_ckpt: Prefix to remove from parameter names when saving checkpoints.
  • Trainable Modules
    • --trainable_models: Models to train, e.g., dit, vae, text_encoder.
    • --lora_base_model: Which base model to apply LoRA on.
    • --lora_target_modules: Which layers to apply LoRA on.
    • --lora_rank: Rank of LoRA.
  • Extra Inputs
    • --extra_inputs: Additional model inputs. Comma-separated.
  • VRAM Management
    • --use_gradient_checkpointing: Whether to use gradient checkpointing.
    • --use_gradient_checkpointing_offload: Whether to offload gradient-checkpointing activations to CPU memory.
    • --gradient_accumulation_steps: Number of gradient accumulation steps.
  • Miscellaneous
    • --align_to_opensource_format: Whether to align the FLUX DiT LoRA format with the open-source version. Only applicable to LoRA training for FLUX.1-dev and FLUX.1-Kontext-dev.
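
Gradient accumulation trades time for memory: gradients from several steps are accumulated before each optimizer update. The arithmetic below shows the resulting effective batch size and a rough count of updates per epoch with --dataset_repeat. The per-GPU batch size and GPU count are hypothetical inputs, since the parameters above do not state them, and the helper names are illustrative.

```python
def effective_batch_size(per_gpu_batch, num_gpus, gradient_accumulation_steps):
    """Samples contributing to each optimizer update."""
    return per_gpu_batch * num_gpus * gradient_accumulation_steps


def approx_updates_per_epoch(num_samples, dataset_repeat, batch):
    """Rough optimizer updates per epoch; ignores any remainder batch."""
    return (num_samples * dataset_repeat) // batch


# Hypothetical setup: 1 sample per GPU per step, 8 GPUs, accumulate over 4 steps.
batch = effective_batch_size(per_gpu_batch=1, num_gpus=8, gradient_accumulation_steps=4)
print(batch)  # 32
print(approx_updates_per_epoch(num_samples=100, dataset_repeat=50, batch=batch))  # 156
```
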

Step 1: Prepare Dataset

The dataset contains a series of files. We recommend organizing your dataset files as follows:

data/example_image_dataset/
├── metadata.csv
├── image1.jpg
└── image2.jpg

Here, image1.jpg and image2.jpg are the training images, and metadata.csv is the metadata list, for example:

image,prompt
image1.jpg,"a cat is sleeping"
image2.jpg,"a dog is running"
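
A metadata.csv in this format can be generated with Python's csv module, which quotes a prompt only when needed (for example, when it contains a comma); quoted and unquoted prompts parse the same. The sketch uses an image column, matching the --data_file_keys "image,kontext_images" example later in this section; the helper names and the prompt mapping are hypothetical.

```python
import csv
import io
import os


def format_metadata(prompts, extra_columns=None):
    """Return metadata.csv text with columns image, prompt (+ any extra columns).

    prompts: dict mapping image file name -> prompt text.
    extra_columns: optional dict mapping column name -> {image file name: value}.
    """
    extra_columns = extra_columns or {}
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=["image", "prompt", *extra_columns], lineterminator="\n")
    writer.writeheader()
    for image, prompt in prompts.items():
        row = {"image": image, "prompt": prompt}
        for column, values in extra_columns.items():
            row[column] = values[image]
        writer.writerow(row)
    return buffer.getvalue()


def write_metadata(dataset_dir, prompts, extra_columns=None):
    """Write <dataset_dir>/metadata.csv and return its path."""
    path = os.path.join(dataset_dir, "metadata.csv")
    with open(path, "w", newline="", encoding="utf-8") as f:
        f.write(format_metadata(prompts, extra_columns))
    return path
```

For example, write_metadata("data/example_image_dataset", {"image1.jpg": "a cat is sleeping", "image2.jpg": "a dog is running"}) writes the two rows above; passing extra_columns={"kontext_images": {...}} adds a column for additional inputs.
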

We have built a sample image dataset to help you test more conveniently. You can download this dataset using the following command:

modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset

The dataset supports multiple image formats: "jpg", "jpeg", "png", "webp".

The image resolution can be controlled via the script parameters --height and --width. When both are left empty, dynamic resolution is enabled, allowing training at the actual width and height of each image in the dataset.

We strongly recommend using fixed-resolution training, as there may be load-balancing issues in multi-GPU training with dynamic resolution.

When the model requires additional inputs, for instance kontext_images for black-forest-labs/FLUX.1-Kontext-dev, add corresponding columns to the dataset, for example:

image,prompt,kontext_images
image1.jpg,"a cat is sleeping",image1_reference.jpg

If the additional inputs include image files, specify every column that should be parsed as a file with the --data_file_keys parameter, e.g., --data_file_keys "image,kontext_images".
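
Before launching training, it can help to confirm that every file named in the --data_file_keys columns exists under --dataset_base_path. The sketch assumes metadata paths are relative to the dataset root and that each cell holds a single file name; it is a hypothetical pre-flight check, not part of the training script.

```python
import csv
import os


def find_missing_files(dataset_base_path, rows, data_file_keys):
    """Return (row_number, column, path) for files referenced in rows but absent on disk.

    rows: iterable of dicts, e.g. csv.DictReader over metadata.csv.
    data_file_keys: the comma-separated string passed to --data_file_keys.
    """
    keys = [key.strip() for key in data_file_keys.split(",") if key.strip()]
    missing = []
    for row_number, row in enumerate(rows, start=1):
        for key in keys:
            if key not in row:
                raise KeyError(f"column {key!r} not found in metadata")
            path = os.path.join(dataset_base_path, row[key])
            if not os.path.isfile(path):
                missing.append((row_number, key, path))
    return missing


def check_dataset(dataset_base_path, metadata_path, data_file_keys):
    """Open metadata.csv and run find_missing_files over its rows."""
    with open(metadata_path, newline="", encoding="utf-8") as f:
        return find_missing_files(dataset_base_path, csv.DictReader(f), data_file_keys)
```
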

Step 2: Load Model

Similar to the model loading logic during inference, you can directly configure the model to be loaded using its model ID. For example, during inference we load the model with the following configuration:

model_configs=[
    ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
    ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
    ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
    ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
]

Then during training, simply provide the following parameter to load the corresponding model:

--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors"
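
The mapping from inference-time ModelConfig entries to this flag is mechanical: join each model_id and origin_file_pattern with a colon, then join the entries with commas. A small sketch, using plain tuples in place of ModelConfig so it runs without DiffSynth-Studio installed (the helper name is illustrative):

```python
def model_id_with_origin_paths(configs):
    """configs: list of (model_id, origin_file_pattern) tuples."""
    return ",".join(f"{model_id}:{pattern}" for model_id, pattern in configs)


flux_dev = [
    ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors"),
    ("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors"),
    ("black-forest-labs/FLUX.1-dev", "text_encoder_2/"),
    ("black-forest-labs/FLUX.1-dev", "ae.safetensors"),
]
print(f'--model_id_with_origin_paths "{model_id_with_origin_paths(flux_dev)}"')
```
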

If you prefer to load the model from local files, as in the inference example:

model_configs=[
    ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors"),
    ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors"),
    ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder_2/"),
    ModelConfig(path="models/black-forest-labs/FLUX.1-dev/ae.safetensors"),
]

Then during training, set it up as follows:

--model_paths '[
    "models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors",
    "models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors",
    "models/black-forest-labs/FLUX.1-dev/text_encoder_2/",
    "models/black-forest-labs/FLUX.1-dev/ae.safetensors"
]'
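
Because --model_paths takes a JSON list, generating it with json.dumps avoids quoting mistakes, and shlex.quote then wraps the result as a single shell argument. A sketch using the local paths above:

```python
import json
import shlex

local_paths = [
    "models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors",
    "models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors",
    "models/black-forest-labs/FLUX.1-dev/text_encoder_2/",
    "models/black-forest-labs/FLUX.1-dev/ae.safetensors",
]

# json.dumps produces a valid JSON list; shlex.quote makes it one shell argument.
model_paths_arg = shlex.quote(json.dumps(local_paths))
print("--model_paths " + model_paths_arg)
```
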

Step 3: Configure Trainable Modules

The training framework supports both full-model training and LoRA-based fine-tuning. Below are some examples:

  • Full training of the DiT module: --trainable_models dit
  • Training a LoRA model on the DiT module: --lora_base_model dit --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" --lora_rank 32
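
LoRA adds two low-rank matrices to each targeted linear layer, A (rank x in_features) and B (out_features x rank), so the trainable parameter count grows linearly with --lora_rank. The arithmetic below uses a hypothetical 3072-to-9216 projection purely for illustration; check the actual layer shapes in the FLUX DiT before relying on the numbers.

```python
def lora_params(in_features, out_features, rank):
    """Trainable parameters LoRA adds to one linear layer: A (rank x in) + B (out x rank)."""
    return rank * in_features + out_features * rank


# Hypothetical fused QKV projection: 3072 inputs -> 3 * 3072 outputs.
print(lora_params(3072, 9216, rank=32))  # 393216
print(lora_params(3072, 9216, rank=16))  # 196608
```

For comparison, the full weight of such a layer has 3072 x 9216 = 28,311,552 parameters, so rank 32 trains roughly 1.4% of it.
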

Additionally, because the training script loads several modules (text encoder, DiT, VAE), parameter names carry a module prefix that must be removed when saving the model files. For example, for full DiT training or LoRA training on the DiT module, set --remove_prefix_in_ckpt "pipe.dit.".
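
Conceptually, this option strips a leading prefix from parameter names before the checkpoint is written, so a key such as pipe.dit.<layer> becomes <layer>. The sketch below illustrates the idea on a plain dictionary; it is not the training script's implementation, the key names are hypothetical, and keeping non-matching keys unchanged is an assumption.

```python
def strip_prefix(state_dict, prefix):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): value
        for name, value in state_dict.items()
    }


# Hypothetical parameter names; the string values stand in for tensors.
ckpt = {
    "pipe.dit.blocks.0.attn.a_to_qkv.lora_A.weight": "A",
    "pipe.dit.blocks.0.attn.a_to_qkv.lora_B.weight": "B",
}
print(strip_prefix(ckpt, "pipe.dit."))
```
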

Step 4: Launch the Training Script

We have written specific training commands for each model. Please refer to the table at the beginning of this document for details.
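
As a rough template only, the parameters documented above can be combined into a LoRA training command. The sketch builds the argument list in Python and prints it with shlex.join. All values (paths, resolution, learning rate, epochs, repeat count, output path) are hypothetical placeholders, and launching through accelerate launch from the repository root is an assumption; the per-model scripts referenced in the table remain the reference versions.

```python
import shlex

# All values are hypothetical placeholders; adjust them to your data and hardware.
args = [
    "accelerate", "launch",  # assumption: launched via Hugging Face Accelerate
    "examples/flux/model_training/train.py",  # assumption: run from the repository root
    "--dataset_base_path", "data/example_image_dataset",
    "--dataset_metadata_path", "data/example_image_dataset/metadata.csv",
    "--height", "1024",
    "--width", "1024",
    "--dataset_repeat", "50",
    "--model_id_with_origin_paths",
    "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,"
    "black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,"
    "black-forest-labs/FLUX.1-dev:text_encoder_2/,"
    "black-forest-labs/FLUX.1-dev:ae.safetensors",
    "--learning_rate", "1e-4",
    "--num_epochs", "5",
    "--remove_prefix_in_ckpt", "pipe.dit.",
    "--output_path", "./models/train/FLUX.1-dev_lora",
    "--lora_base_model", "dit",
    "--lora_target_modules",
    "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,"
    "proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp",
    "--lora_rank", "32",
]
print(shlex.join(args))
```
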