mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
flux.2
This commit is contained in:
@@ -429,5 +429,26 @@ flux_series = [
|
|||||||
"extra_kwargs": {"disable_guidance_embedder": True},
|
"extra_kwargs": {"disable_guidance_embedder": True},
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
flux2_series = [
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors")
|
||||||
|
"model_hash": "28fca3d8e5bf2a2d1271748a773f6757",
|
||||||
|
"model_name": "flux2_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.flux2_text_encoder.Flux2TextEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux2_text_encoder.Flux2TextEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors")
|
||||||
|
"model_hash": "d38e1d5c5aec3b0a11e79327ac6e3b0f",
|
||||||
|
"model_name": "flux2_dit",
|
||||||
|
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "c54288e3ee12ca215898840682337b95",
|
||||||
|
"model_name": "flux2_vae",
|
||||||
|
"model_class": "diffsynth.models.flux2_vae.Flux2VAE",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series
|
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series
|
||||||
|
|||||||
@@ -150,4 +150,9 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
|||||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
},
|
},
|
||||||
|
"diffsynth.models.flux2_dit.Flux2DiT": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
1057
diffsynth/models/flux2_dit.py
Normal file
1057
diffsynth/models/flux2_dit.py
Normal file
File diff suppressed because it is too large
Load Diff
58
diffsynth/models/flux2_text_encoder.py
Normal file
58
diffsynth/models/flux2_text_encoder.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
from transformers import Mistral3ForConditionalGeneration, Mistral3Config
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2TextEncoder(Mistral3ForConditionalGeneration):
|
||||||
|
def __init__(self):
|
||||||
|
config = Mistral3Config(**{
|
||||||
|
"architectures": [
|
||||||
|
"Mistral3ForConditionalGeneration"
|
||||||
|
],
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"image_token_index": 10,
|
||||||
|
"model_type": "mistral3",
|
||||||
|
"multimodal_projector_bias": False,
|
||||||
|
"projector_hidden_act": "gelu",
|
||||||
|
"spatial_merge_size": 2,
|
||||||
|
"text_config": {
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"head_dim": 128,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 5120,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 32768,
|
||||||
|
"max_position_embeddings": 131072,
|
||||||
|
"model_type": "mistral",
|
||||||
|
"num_attention_heads": 32,
|
||||||
|
"num_hidden_layers": 40,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"rms_norm_eps": 1e-05,
|
||||||
|
"rope_theta": 1000000000.0,
|
||||||
|
"sliding_window": None,
|
||||||
|
"use_cache": True,
|
||||||
|
"vocab_size": 131072
|
||||||
|
},
|
||||||
|
"transformers_version": "4.57.1",
|
||||||
|
"vision_config": {
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"head_dim": 64,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 1024,
|
||||||
|
"image_size": 1540,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 4096,
|
||||||
|
"model_type": "pixtral",
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"num_channels": 3,
|
||||||
|
"num_hidden_layers": 24,
|
||||||
|
"patch_size": 14,
|
||||||
|
"rope_theta": 10000.0
|
||||||
|
},
|
||||||
|
"vision_feature_layer": -1
|
||||||
|
})
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
def forward(self, input_ids = None, pixel_values = None, attention_mask = None, position_ids = None, past_key_values = None, inputs_embeds = None, labels = None, use_cache = None, output_attentions = None, output_hidden_states = None, return_dict = None, cache_position = None, logits_to_keep = 0, image_sizes = None, **kwargs):
|
||||||
|
return super().forward(input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, image_sizes, **kwargs)
|
||||||
|
|
||||||
468
diffsynth/models/flux2_vae.py
Normal file
468
diffsynth/models/flux2_vae.py
Normal file
@@ -0,0 +1,468 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import math
|
||||||
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
from diffusers.models.autoencoders.autoencoder_kl_flux2 import Decoder, Encoder
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2VAE(torch.nn.Module):
|
||||||
|
r"""
|
||||||
|
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
||||||
|
|
||||||
|
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||||
|
for all models (such as downloading or saving).
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||||
|
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||||
|
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||||
|
Tuple of downsample block types.
|
||||||
|
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||||
|
Tuple of upsample block types.
|
||||||
|
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||||
|
Tuple of block output channels.
|
||||||
|
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||||
|
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
|
||||||
|
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
||||||
|
force_upcast (`bool`, *optional*, default to `True`):
|
||||||
|
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
|
||||||
|
can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`
|
||||||
|
can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
|
||||||
|
mid_block_add_attention (`bool`, *optional*, default to `True`):
|
||||||
|
If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
|
||||||
|
mid_block will only have resnet blocks
|
||||||
|
"""
|
||||||
|
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 3,
|
||||||
|
out_channels: int = 3,
|
||||||
|
down_block_types: Tuple[str, ...] = (
|
||||||
|
"DownEncoderBlock2D",
|
||||||
|
"DownEncoderBlock2D",
|
||||||
|
"DownEncoderBlock2D",
|
||||||
|
"DownEncoderBlock2D",
|
||||||
|
),
|
||||||
|
up_block_types: Tuple[str, ...] = (
|
||||||
|
"UpDecoderBlock2D",
|
||||||
|
"UpDecoderBlock2D",
|
||||||
|
"UpDecoderBlock2D",
|
||||||
|
"UpDecoderBlock2D",
|
||||||
|
),
|
||||||
|
block_out_channels: Tuple[int, ...] = (
|
||||||
|
128,
|
||||||
|
256,
|
||||||
|
512,
|
||||||
|
512,
|
||||||
|
),
|
||||||
|
layers_per_block: int = 2,
|
||||||
|
act_fn: str = "silu",
|
||||||
|
latent_channels: int = 32,
|
||||||
|
norm_num_groups: int = 32,
|
||||||
|
sample_size: int = 1024, # YiYi notes: not sure
|
||||||
|
force_upcast: bool = True,
|
||||||
|
use_quant_conv: bool = True,
|
||||||
|
use_post_quant_conv: bool = True,
|
||||||
|
mid_block_add_attention: bool = True,
|
||||||
|
batch_norm_eps: float = 1e-4,
|
||||||
|
batch_norm_momentum: float = 0.1,
|
||||||
|
patch_size: Tuple[int, int] = (2, 2),
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# pass init params to Encoder
|
||||||
|
self.encoder = Encoder(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=latent_channels,
|
||||||
|
down_block_types=down_block_types,
|
||||||
|
block_out_channels=block_out_channels,
|
||||||
|
layers_per_block=layers_per_block,
|
||||||
|
act_fn=act_fn,
|
||||||
|
norm_num_groups=norm_num_groups,
|
||||||
|
double_z=True,
|
||||||
|
mid_block_add_attention=mid_block_add_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
# pass init params to Decoder
|
||||||
|
self.decoder = Decoder(
|
||||||
|
in_channels=latent_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
up_block_types=up_block_types,
|
||||||
|
block_out_channels=block_out_channels,
|
||||||
|
layers_per_block=layers_per_block,
|
||||||
|
norm_num_groups=norm_num_groups,
|
||||||
|
act_fn=act_fn,
|
||||||
|
mid_block_add_attention=mid_block_add_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
|
||||||
|
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
|
||||||
|
|
||||||
|
self.bn = nn.BatchNorm2d(
|
||||||
|
math.prod(patch_size) * latent_channels,
|
||||||
|
eps=batch_norm_eps,
|
||||||
|
momentum=batch_norm_momentum,
|
||||||
|
affine=False,
|
||||||
|
track_running_stats=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.use_slicing = False
|
||||||
|
self.use_tiling = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||||
|
def attn_processors(self):
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||||
|
indexed by its weight name.
|
||||||
|
"""
|
||||||
|
# set recursively
|
||||||
|
processors = {}
|
||||||
|
|
||||||
|
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
|
||||||
|
if hasattr(module, "get_processor"):
|
||||||
|
processors[f"{name}.processor"] = module.get_processor()
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_add_processors(name, module, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||||
|
def set_attn_processor(self, processor):
|
||||||
|
r"""
|
||||||
|
Sets the attention processor to use to compute attention.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||||
|
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||||
|
for **all** `Attention` layers.
|
||||||
|
|
||||||
|
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||||
|
processor. This is strongly recommended when setting trainable attention processors.
|
||||||
|
|
||||||
|
"""
|
||||||
|
count = len(self.attn_processors.keys())
|
||||||
|
|
||||||
|
if isinstance(processor, dict) and len(processor) != count:
|
||||||
|
raise ValueError(
|
||||||
|
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||||
|
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||||
|
)
|
||||||
|
|
||||||
|
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||||
|
if hasattr(module, "set_processor"):
|
||||||
|
if not isinstance(processor, dict):
|
||||||
|
module.set_processor(processor)
|
||||||
|
else:
|
||||||
|
module.set_processor(processor.pop(f"{name}.processor"))
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_attn_processor(name, module, processor)
|
||||||
|
|
||||||
|
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size, num_channels, height, width = x.shape
|
||||||
|
|
||||||
|
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
|
||||||
|
return self._tiled_encode(x)
|
||||||
|
|
||||||
|
enc = self.encoder(x)
|
||||||
|
if self.quant_conv is not None:
|
||||||
|
enc = self.quant_conv(enc)
|
||||||
|
|
||||||
|
return enc
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self, x: torch.Tensor, return_dict: bool = True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Encode a batch of images into latents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (`torch.Tensor`): Input batch of images.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The latent representations of the encoded images. If `return_dict` is True, a
|
||||||
|
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||||
|
"""
|
||||||
|
if self.use_slicing and x.shape[0] > 1:
|
||||||
|
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||||
|
h = torch.cat(encoded_slices)
|
||||||
|
else:
|
||||||
|
h = self._encode(x)
|
||||||
|
|
||||||
|
|
||||||
|
h = rearrange(h, "B C (H P) (W Q) -> B (C P Q) H W", P=2, Q=2)
|
||||||
|
h = h[:, :128]
|
||||||
|
latents_bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(h.device, h.dtype)
|
||||||
|
latents_bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + 0.0001).to(
|
||||||
|
h.device, h.dtype
|
||||||
|
)
|
||||||
|
h = (h - latents_bn_mean) / latents_bn_std
|
||||||
|
return h
|
||||||
|
|
||||||
|
def _decode(self, z: torch.Tensor, return_dict: bool = True):
|
||||||
|
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
||||||
|
return self.tiled_decode(z, return_dict=return_dict)
|
||||||
|
|
||||||
|
if self.post_quant_conv is not None:
|
||||||
|
z = self.post_quant_conv(z)
|
||||||
|
|
||||||
|
dec = self.decoder(z)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (dec,)
|
||||||
|
|
||||||
|
return dec
|
||||||
|
|
||||||
|
def decode(
|
||||||
|
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
||||||
|
):
|
||||||
|
latents_bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(z.device, z.dtype)
|
||||||
|
latents_bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + 0.0001).to(
|
||||||
|
z.device, z.dtype
|
||||||
|
)
|
||||||
|
z = z * latents_bn_std + latents_bn_mean
|
||||||
|
z = rearrange(z, "B (C P Q) H W -> B C (H P) (W Q)", P=2, Q=2)
|
||||||
|
"""
|
||||||
|
Decode a batch of images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
z (`torch.Tensor`): Input batch of latent vectors.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||||
|
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||||
|
returned.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self.use_slicing and z.shape[0] > 1:
|
||||||
|
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
|
||||||
|
decoded = torch.cat(decoded_slices)
|
||||||
|
else:
|
||||||
|
decoded = self._decode(z)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (decoded,)
|
||||||
|
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||||
|
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
||||||
|
for y in range(blend_extent):
|
||||||
|
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
|
||||||
|
return b
|
||||||
|
|
||||||
|
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||||
|
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
||||||
|
for x in range(blend_extent):
|
||||||
|
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
||||||
|
return b
|
||||||
|
|
||||||
|
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
r"""Encode a batch of images using a tiled encoder.
|
||||||
|
|
||||||
|
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||||
|
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
||||||
|
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
||||||
|
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
||||||
|
output, but they should be much less noticeable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (`torch.Tensor`): Input batch of images.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.Tensor`:
|
||||||
|
The latent representation of the encoded videos.
|
||||||
|
"""
|
||||||
|
|
||||||
|
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||||
|
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||||
|
row_limit = self.tile_latent_min_size - blend_extent
|
||||||
|
|
||||||
|
# Split the image into 512x512 tiles and encode them separately.
|
||||||
|
rows = []
|
||||||
|
for i in range(0, x.shape[2], overlap_size):
|
||||||
|
row = []
|
||||||
|
for j in range(0, x.shape[3], overlap_size):
|
||||||
|
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||||
|
tile = self.encoder(tile)
|
||||||
|
if self.config.use_quant_conv:
|
||||||
|
tile = self.quant_conv(tile)
|
||||||
|
row.append(tile)
|
||||||
|
rows.append(row)
|
||||||
|
result_rows = []
|
||||||
|
for i, row in enumerate(rows):
|
||||||
|
result_row = []
|
||||||
|
for j, tile in enumerate(row):
|
||||||
|
# blend the above tile and the left tile
|
||||||
|
# to the current tile and add the current tile to the result row
|
||||||
|
if i > 0:
|
||||||
|
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||||
|
if j > 0:
|
||||||
|
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||||
|
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||||
|
result_rows.append(torch.cat(result_row, dim=3))
|
||||||
|
|
||||||
|
enc = torch.cat(result_rows, dim=2)
|
||||||
|
return enc
|
||||||
|
|
||||||
|
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True):
|
||||||
|
r"""Encode a batch of images using a tiled encoder.
|
||||||
|
|
||||||
|
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||||
|
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
||||||
|
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
||||||
|
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
||||||
|
output, but they should be much less noticeable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (`torch.Tensor`): Input batch of images.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
|
||||||
|
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
|
||||||
|
`tuple` is returned.
|
||||||
|
"""
|
||||||
|
|
||||||
|
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||||
|
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||||
|
row_limit = self.tile_latent_min_size - blend_extent
|
||||||
|
|
||||||
|
# Split the image into 512x512 tiles and encode them separately.
|
||||||
|
rows = []
|
||||||
|
for i in range(0, x.shape[2], overlap_size):
|
||||||
|
row = []
|
||||||
|
for j in range(0, x.shape[3], overlap_size):
|
||||||
|
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||||
|
tile = self.encoder(tile)
|
||||||
|
if self.config.use_quant_conv:
|
||||||
|
tile = self.quant_conv(tile)
|
||||||
|
row.append(tile)
|
||||||
|
rows.append(row)
|
||||||
|
result_rows = []
|
||||||
|
for i, row in enumerate(rows):
|
||||||
|
result_row = []
|
||||||
|
for j, tile in enumerate(row):
|
||||||
|
# blend the above tile and the left tile
|
||||||
|
# to the current tile and add the current tile to the result row
|
||||||
|
if i > 0:
|
||||||
|
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||||
|
if j > 0:
|
||||||
|
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||||
|
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||||
|
result_rows.append(torch.cat(result_row, dim=3))
|
||||||
|
|
||||||
|
moments = torch.cat(result_rows, dim=2)
|
||||||
|
return moments
|
||||||
|
|
||||||
|
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True):
|
||||||
|
r"""
|
||||||
|
Decode a batch of images using a tiled decoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
z (`torch.Tensor`): Input batch of latent vectors.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||||
|
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||||
|
returned.
|
||||||
|
"""
|
||||||
|
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
||||||
|
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
||||||
|
row_limit = self.tile_sample_min_size - blend_extent
|
||||||
|
|
||||||
|
# Split z into overlapping 64x64 tiles and decode them separately.
|
||||||
|
# The tiles have an overlap to avoid seams between tiles.
|
||||||
|
rows = []
|
||||||
|
for i in range(0, z.shape[2], overlap_size):
|
||||||
|
row = []
|
||||||
|
for j in range(0, z.shape[3], overlap_size):
|
||||||
|
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
||||||
|
if self.config.use_post_quant_conv:
|
||||||
|
tile = self.post_quant_conv(tile)
|
||||||
|
decoded = self.decoder(tile)
|
||||||
|
row.append(decoded)
|
||||||
|
rows.append(row)
|
||||||
|
result_rows = []
|
||||||
|
for i, row in enumerate(rows):
|
||||||
|
result_row = []
|
||||||
|
for j, tile in enumerate(row):
|
||||||
|
# blend the above tile and the left tile
|
||||||
|
# to the current tile and add the current tile to the result row
|
||||||
|
if i > 0:
|
||||||
|
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||||
|
if j > 0:
|
||||||
|
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||||
|
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||||
|
result_rows.append(torch.cat(result_row, dim=3))
|
||||||
|
|
||||||
|
dec = torch.cat(result_rows, dim=2)
|
||||||
|
if not return_dict:
|
||||||
|
return (dec,)
|
||||||
|
|
||||||
|
return dec
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
sample: torch.Tensor,
|
||||||
|
sample_posterior: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
|
generator: Optional[torch.Generator] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
sample (`torch.Tensor`): Input sample.
|
||||||
|
sample_posterior (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to sample from the posterior.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||||
|
"""
|
||||||
|
x = sample
|
||||||
|
posterior = self.encode(x).latent_dist
|
||||||
|
if sample_posterior:
|
||||||
|
z = posterior.sample(generator=generator)
|
||||||
|
else:
|
||||||
|
z = posterior.mode()
|
||||||
|
dec = self.decode(z).sample
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (dec,)
|
||||||
|
|
||||||
|
return dec
|
||||||
371
diffsynth/pipelines/flux2_image.py
Normal file
371
diffsynth/pipelines/flux2_image.py
Normal file
@@ -0,0 +1,371 @@
|
|||||||
|
import torch, math
|
||||||
|
from PIL import Image
|
||||||
|
from typing import Union
|
||||||
|
from tqdm import tqdm
|
||||||
|
from einops import rearrange
|
||||||
|
import numpy as np
|
||||||
|
from typing import Union, List, Optional, Tuple
|
||||||
|
|
||||||
|
from ..diffusion import FlowMatchScheduler
|
||||||
|
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||||
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||||
|
|
||||||
|
from transformers import AutoProcessor
|
||||||
|
from ..models.flux2_text_encoder import Flux2TextEncoder
|
||||||
|
from ..models.flux2_dit import Flux2DiT
|
||||||
|
from ..models.flux2_vae import Flux2VAE
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2ImagePipeline(BasePipeline):
|
||||||
|
|
||||||
|
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||||
|
super().__init__(
|
||||||
|
device=device, torch_dtype=torch_dtype,
|
||||||
|
height_division_factor=16, width_division_factor=16,
|
||||||
|
)
|
||||||
|
self.scheduler = FlowMatchScheduler()
|
||||||
|
self.text_encoder: Flux2TextEncoder = None
|
||||||
|
self.dit: Flux2DiT = None
|
||||||
|
self.vae: Flux2VAE = None
|
||||||
|
self.tokenizer: AutoProcessor = None
|
||||||
|
self.in_iteration_models = ("dit",)
|
||||||
|
self.units = [
|
||||||
|
Flux2Unit_ShapeChecker(),
|
||||||
|
Flux2Unit_PromptEmbedder(),
|
||||||
|
Flux2Unit_NoiseInitializer(),
|
||||||
|
Flux2Unit_InputImageEmbedder(),
|
||||||
|
Flux2Unit_ImageIDs(),
|
||||||
|
]
|
||||||
|
self.model_fn = model_fn_flux2
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pretrained(
|
||||||
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
|
device: Union[str, torch.device] = "cuda",
|
||||||
|
model_configs: list[ModelConfig] = [],
|
||||||
|
tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit: float = None,
|
||||||
|
):
|
||||||
|
# Initialize pipeline
|
||||||
|
pipe = Flux2ImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||||
|
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||||
|
|
||||||
|
# Fetch models
|
||||||
|
pipe.text_encoder = model_pool.fetch_model("flux2_text_encoder")
|
||||||
|
pipe.dit = model_pool.fetch_model("flux2_dit")
|
||||||
|
pipe.vae = model_pool.fetch_model("flux2_vae")
|
||||||
|
if tokenizer_config is not None:
|
||||||
|
tokenizer_config.download_if_necessary()
|
||||||
|
pipe.tokenizer = AutoProcessor.from_pretrained(tokenizer_config.path)
|
||||||
|
|
||||||
|
# VRAM Management
|
||||||
|
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
# Prompt
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
cfg_scale: float = 1.0,
|
||||||
|
embedded_guidance: float = 4.0,
|
||||||
|
# Image
|
||||||
|
input_image: Image.Image = None,
|
||||||
|
denoising_strength: float = 1.0,
|
||||||
|
# Shape
|
||||||
|
height: int = 1024,
|
||||||
|
width: int = 1024,
|
||||||
|
# Randomness
|
||||||
|
seed: int = None,
|
||||||
|
rand_device: str = "cpu",
|
||||||
|
# Steps
|
||||||
|
num_inference_steps: int = 30,
|
||||||
|
# Progress bar
|
||||||
|
progress_bar_cmd = tqdm,
|
||||||
|
):
|
||||||
|
# Scheduler
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
inputs_posi = {
|
||||||
|
"prompt": prompt,
|
||||||
|
}
|
||||||
|
inputs_nega = {
|
||||||
|
"negative_prompt": negative_prompt,
|
||||||
|
}
|
||||||
|
inputs_shared = {
|
||||||
|
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
|
||||||
|
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||||
|
"height": height, "width": width,
|
||||||
|
"seed": seed, "rand_device": rand_device,
|
||||||
|
"num_inference_steps": num_inference_steps,
|
||||||
|
}
|
||||||
|
for unit in self.units:
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||||
|
|
||||||
|
# Denoise
|
||||||
|
self.load_models_to_device(self.in_iteration_models)
|
||||||
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
noise_pred = self.cfg_guided_model_fn(
|
||||||
|
self.model_fn, cfg_scale,
|
||||||
|
inputs_shared, inputs_posi, inputs_nega,
|
||||||
|
**models, timestep=timestep, progress_id=progress_id
|
||||||
|
)
|
||||||
|
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||||
|
|
||||||
|
# Decode
|
||||||
|
self.load_models_to_device(['vae'])
|
||||||
|
latents = rearrange(inputs_shared["latents"], "B (H W) C -> B C H W", H=inputs_shared["height"]//16, W=inputs_shared["width"]//16)
|
||||||
|
image = self.vae.decode(latents)
|
||||||
|
image = self.vae_output_to_image(image)
|
||||||
|
self.load_models_to_device([])
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2Unit_ShapeChecker(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("height", "width"),
|
||||||
|
output_params=("height", "width"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: Flux2ImagePipeline, height, width):
|
||||||
|
height, width = pipe.check_resize_height_width(height, width)
|
||||||
|
return {"height": height, "width": width}
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2Unit_PromptEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
seperate_cfg=True,
|
||||||
|
input_params_posi={"prompt": "prompt"},
|
||||||
|
input_params_nega={"prompt": "negative_prompt"},
|
||||||
|
output_params=("prompt_emb", "prompt_emb_mask"),
|
||||||
|
onload_model_names=("text_encoder",)
|
||||||
|
)
|
||||||
|
self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
|
||||||
|
|
||||||
|
def format_text_input(self, prompts: List[str], system_message: str = None):
|
||||||
|
# Remove [IMG] tokens from prompts to avoid Pixtral validation issues
|
||||||
|
# when truncation is enabled. The processor counts [IMG] tokens and fails
|
||||||
|
# if the count changes after truncation.
|
||||||
|
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
|
||||||
|
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [{"type": "text", "text": system_message}],
|
||||||
|
},
|
||||||
|
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||||
|
]
|
||||||
|
for prompt in cleaned_txt
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_mistral_3_small_prompt_embeds(
|
||||||
|
self,
|
||||||
|
text_encoder,
|
||||||
|
tokenizer,
|
||||||
|
prompt: Union[str, List[str]],
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
# fmt: off
|
||||||
|
system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
|
||||||
|
# fmt: on
|
||||||
|
hidden_states_layers: List[int] = (10, 20, 30),
|
||||||
|
):
|
||||||
|
dtype = text_encoder.dtype if dtype is None else dtype
|
||||||
|
device = text_encoder.device if device is None else device
|
||||||
|
|
||||||
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
|
||||||
|
# Format input messages
|
||||||
|
messages_batch = self.format_text_input(prompts=prompt, system_message=system_message)
|
||||||
|
|
||||||
|
# Process all messages at once
|
||||||
|
inputs = tokenizer.apply_chat_template(
|
||||||
|
messages_batch,
|
||||||
|
add_generation_prompt=False,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_sequence_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Move to device
|
||||||
|
input_ids = inputs["input_ids"].to(device)
|
||||||
|
attention_mask = inputs["attention_mask"].to(device)
|
||||||
|
|
||||||
|
# Forward pass through the model
|
||||||
|
output = text_encoder(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only use outputs from intermediate layers and stack them
|
||||||
|
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
|
||||||
|
out = out.to(dtype=dtype, device=device)
|
||||||
|
|
||||||
|
batch_size, num_channels, seq_len, hidden_dim = out.shape
|
||||||
|
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
||||||
|
|
||||||
|
return prompt_embeds
|
||||||
|
|
||||||
|
def prepare_text_ids(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor, # (B, L, D) or (L, D)
|
||||||
|
t_coord: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
B, L, _ = x.shape
|
||||||
|
out_ids = []
|
||||||
|
|
||||||
|
for i in range(B):
|
||||||
|
t = torch.arange(1) if t_coord is None else t_coord[i]
|
||||||
|
h = torch.arange(1)
|
||||||
|
w = torch.arange(1)
|
||||||
|
l = torch.arange(L)
|
||||||
|
|
||||||
|
coords = torch.cartesian_prod(t, h, w, l)
|
||||||
|
out_ids.append(coords)
|
||||||
|
|
||||||
|
return torch.stack(out_ids)
|
||||||
|
|
||||||
|
def encode_prompt(
|
||||||
|
self,
|
||||||
|
text_encoder,
|
||||||
|
tokenizer,
|
||||||
|
prompt: Union[str, List[str]],
|
||||||
|
dtype = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
num_images_per_prompt: int = 1,
|
||||||
|
prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
text_encoder_out_layers: Tuple[int] = (10, 20, 30),
|
||||||
|
):
|
||||||
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
|
||||||
|
if prompt_embeds is None:
|
||||||
|
prompt_embeds = self.get_mistral_3_small_prompt_embeds(
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
prompt=prompt,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
system_message=self.system_message,
|
||||||
|
hidden_states_layers=text_encoder_out_layers,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_size, seq_len, _ = prompt_embeds.shape
|
||||||
|
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||||
|
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||||
|
|
||||||
|
text_ids = self.prepare_text_ids(prompt_embeds)
|
||||||
|
text_ids = text_ids.to(device)
|
||||||
|
return prompt_embeds, text_ids
|
||||||
|
|
||||||
|
def process(self, pipe: Flux2ImagePipeline, prompt):
|
||||||
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
|
prompt_embeds, text_ids = self.encode_prompt(
|
||||||
|
pipe.text_encoder, pipe.tokenizer, prompt,
|
||||||
|
dtype=pipe.torch_dtype, device=pipe.device,
|
||||||
|
)
|
||||||
|
return {"prompt_embeds": prompt_embeds, "text_ids": text_ids}
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2Unit_NoiseInitializer(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("height", "width", "seed", "rand_device"),
|
||||||
|
output_params=("noise",),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: Flux2ImagePipeline, height, width, seed, rand_device):
|
||||||
|
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||||
|
noise = noise.reshape(1, 128, height//16 * width//16).permute(0, 2, 1)
|
||||||
|
return {"noise": noise}
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2Unit_InputImageEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("input_image", "noise"),
|
||||||
|
output_params=("latents", "input_latents"),
|
||||||
|
onload_model_names=("vae",)
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: Flux2ImagePipeline, input_image, noise):
|
||||||
|
if input_image is None:
|
||||||
|
return {"latents": noise, "input_latents": None}
|
||||||
|
pipe.load_models_to_device(['vae'])
|
||||||
|
image = pipe.preprocess_image(input_image)
|
||||||
|
input_latents = pipe.vae.encode(image)
|
||||||
|
input_latents = rearrange(input_latents, "B C H W -> B (H W) C")
|
||||||
|
if pipe.scheduler.training:
|
||||||
|
return {"latents": noise, "input_latents": input_latents}
|
||||||
|
else:
|
||||||
|
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||||
|
return {"latents": latents, "input_latents": input_latents}
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2Unit_ImageIDs(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("height", "width"),
|
||||||
|
output_params=("image_ids",),
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_latent_ids(self, height, width):
|
||||||
|
t = torch.arange(1) # [0] - time dimension
|
||||||
|
h = torch.arange(height)
|
||||||
|
w = torch.arange(width)
|
||||||
|
l = torch.arange(1) # [0] - layer dimension
|
||||||
|
|
||||||
|
# Create position IDs: (H*W, 4)
|
||||||
|
latent_ids = torch.cartesian_prod(t, h, w, l)
|
||||||
|
|
||||||
|
# Expand to batch: (B, H*W, 4)
|
||||||
|
latent_ids = latent_ids.unsqueeze(0).expand(1, -1, -1)
|
||||||
|
|
||||||
|
return latent_ids
|
||||||
|
|
||||||
|
def process(self, pipe: Flux2ImagePipeline, height, width):
|
||||||
|
image_ids = self.prepare_latent_ids(height // 16, width // 16).to(pipe.device)
|
||||||
|
return {"image_ids": image_ids}
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn_flux2(
|
||||||
|
dit: Flux2DiT,
|
||||||
|
latents=None,
|
||||||
|
timestep=None,
|
||||||
|
embedded_guidance=None,
|
||||||
|
prompt_embeds=None,
|
||||||
|
text_ids=None,
|
||||||
|
image_ids=None,
|
||||||
|
use_gradient_checkpointing=False,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
|
||||||
|
model_output = dit(
|
||||||
|
hidden_states=latents,
|
||||||
|
timestep=timestep / 1000,
|
||||||
|
guidance=embedded_guidance,
|
||||||
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
txt_ids=text_ids,
|
||||||
|
img_ids=image_ids,
|
||||||
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
|
)
|
||||||
|
return model_output
|
||||||
17
diffsynth/utils/state_dict_converters/flux2_text_encoder.py
Normal file
17
diffsynth/utils/state_dict_converters/flux2_text_encoder.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
def Flux2TextEncoderStateDictConverter(state_dict):
|
||||||
|
rename_dict = {
|
||||||
|
"multi_modal_projector.linear_1.weight": "model.multi_modal_projector.linear_1.weight",
|
||||||
|
"multi_modal_projector.linear_2.weight": "model.multi_modal_projector.linear_2.weight",
|
||||||
|
"multi_modal_projector.norm.weight": "model.multi_modal_projector.norm.weight",
|
||||||
|
"multi_modal_projector.patch_merger.merging_layer.weight": "model.multi_modal_projector.patch_merger.merging_layer.weight",
|
||||||
|
"language_model.lm_head.weight": "lm_head.weight",
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for k in state_dict:
|
||||||
|
k_ = k
|
||||||
|
k_ = k_.replace("language_model.model", "model.language_model")
|
||||||
|
k_ = k_.replace("vision_tower", "model.vision_tower")
|
||||||
|
if k_ in rename_dict:
|
||||||
|
k_ = rename_dict[k_]
|
||||||
|
state_dict_[k_] = state_dict[k]
|
||||||
|
return state_dict_
|
||||||
27
examples/flux2/model_inference/FLUX.2-dev.py
Normal file
27
examples/flux2/model_inference/FLUX.2-dev.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cuda",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = Flux2ImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
|
||||||
|
)
|
||||||
|
prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
|
||||||
|
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
|
||||||
|
image.save("image_FLUX.2-dev.jpg")
|
||||||
32
examples/flux2/model_training/lora/FLUX.2-dev.sh
Normal file
32
examples/flux2/model_training/lora/FLUX.2-dev.sh
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
accelerate launch train.py \
|
||||||
|
--dataset_base_path data/example_image_dataset \
|
||||||
|
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--dataset_repeat 1 \
|
||||||
|
--model_id_with_origin_paths "black-forest-labs/FLUX.2-dev:text_encoder/*.safetensors,black-forest-labs/FLUX.2-dev:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/FLUX.2-dev-LoRA-splited-cache" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--dataset_num_workers 8 \
|
||||||
|
--task "sft:data_process"
|
||||||
|
|
||||||
|
accelerate launch train.py \
|
||||||
|
--dataset_base_path "./models/train/FLUX.2-dev-LoRA-splited-cache" \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--dataset_repeat 50 \
|
||||||
|
--model_id_with_origin_paths "black-forest-labs/FLUX.2-dev:transformer/*.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/FLUX.2-dev-LoRA-splited" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--dataset_num_workers 8 \
|
||||||
|
--task "sft:train"
|
||||||
143
examples/flux2/model_training/train.py
Normal file
143
examples/flux2/model_training/train.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
import torch, os, argparse, accelerate
|
||||||
|
from diffsynth.core import UnifiedDataset
|
||||||
|
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||||
|
from diffsynth.diffusion import *
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2ImageTrainingModule(DiffusionTrainingModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_paths=None, model_id_with_origin_paths=None,
|
||||||
|
tokenizer_path=None,
|
||||||
|
trainable_models=None,
|
||||||
|
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||||
|
preset_lora_path=None, preset_lora_model=None,
|
||||||
|
use_gradient_checkpointing=True,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
extra_inputs=None,
|
||||||
|
fp8_models=None,
|
||||||
|
offload_models=None,
|
||||||
|
device="cpu",
|
||||||
|
task="sft",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# Load models
|
||||||
|
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||||
|
tokenizer_config = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
|
||||||
|
self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
|
||||||
|
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||||
|
|
||||||
|
# Training mode
|
||||||
|
self.switch_pipe_to_training_mode(
|
||||||
|
self.pipe, trainable_models,
|
||||||
|
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
||||||
|
preset_lora_path, preset_lora_model,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Other configs
|
||||||
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
|
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||||
|
self.fp8_models = fp8_models
|
||||||
|
self.task = task
|
||||||
|
self.task_to_loss = {
|
||||||
|
"sft:data_process": lambda pipe, *args: args,
|
||||||
|
"direct_distill:data_process": lambda pipe, *args: args,
|
||||||
|
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_pipeline_inputs(self, data):
|
||||||
|
inputs_posi = {"prompt": data["prompt"]}
|
||||||
|
inputs_nega = {"negative_prompt": ""}
|
||||||
|
inputs_shared = {
|
||||||
|
# Assume you are using this pipeline for inference,
|
||||||
|
# please fill in the input parameters.
|
||||||
|
"input_image": data["image"],
|
||||||
|
"height": data["image"].size[1],
|
||||||
|
"width": data["image"].size[0],
|
||||||
|
# Please do not modify the following parameters
|
||||||
|
# unless you clearly know what this will cause.
|
||||||
|
"embedded_guidance": 1.0,
|
||||||
|
"cfg_scale": 1,
|
||||||
|
"rand_device": self.pipe.device,
|
||||||
|
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||||
|
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||||
|
}
|
||||||
|
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||||
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
|
|
||||||
|
def forward(self, data, inputs=None):
|
||||||
|
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
||||||
|
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||||
|
for unit in self.pipe.units:
|
||||||
|
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
|
||||||
|
loss = self.task_to_loss[self.task](self.pipe, *inputs)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def qwen_image_parser():
|
||||||
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
|
parser = add_general_config(parser)
|
||||||
|
parser = add_image_size_config(parser)
|
||||||
|
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = qwen_image_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
accelerator = accelerate.Accelerator(
|
||||||
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
|
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||||
|
)
|
||||||
|
dataset = UnifiedDataset(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
metadata_path=args.dataset_metadata_path,
|
||||||
|
repeat=args.dataset_repeat,
|
||||||
|
data_file_keys=args.data_file_keys.split(","),
|
||||||
|
main_data_operator=UnifiedDataset.default_image_operator(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
max_pixels=args.max_pixels,
|
||||||
|
height=args.height,
|
||||||
|
width=args.width,
|
||||||
|
height_division_factor=16,
|
||||||
|
width_division_factor=16,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model = Flux2ImageTrainingModule(
|
||||||
|
model_paths=args.model_paths,
|
||||||
|
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||||
|
tokenizer_path=args.tokenizer_path,
|
||||||
|
trainable_models=args.trainable_models,
|
||||||
|
lora_base_model=args.lora_base_model,
|
||||||
|
lora_target_modules=args.lora_target_modules,
|
||||||
|
lora_rank=args.lora_rank,
|
||||||
|
lora_checkpoint=args.lora_checkpoint,
|
||||||
|
preset_lora_path=args.preset_lora_path,
|
||||||
|
preset_lora_model=args.preset_lora_model,
|
||||||
|
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||||
|
extra_inputs=args.extra_inputs,
|
||||||
|
fp8_models=args.fp8_models,
|
||||||
|
offload_models=args.offload_models,
|
||||||
|
task=args.task,
|
||||||
|
device=accelerator.device,
|
||||||
|
)
|
||||||
|
model_logger = ModelLogger(
|
||||||
|
args.output_path,
|
||||||
|
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||||
|
)
|
||||||
|
launcher_map = {
|
||||||
|
"sft:data_process": launch_data_process_task,
|
||||||
|
"direct_distill:data_process": launch_data_process_task,
|
||||||
|
"sft": launch_training_task,
|
||||||
|
"sft:train": launch_training_task,
|
||||||
|
"direct_distill": launch_training_task,
|
||||||
|
"direct_distill:train": launch_training_task,
|
||||||
|
}
|
||||||
|
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
||||||
28
examples/flux2/model_training/validate_lora/FLUX.2-dev.py
Normal file
28
examples/flux2/model_training/validate_lora/FLUX.2-dev.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cuda",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = Flux2ImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
|
||||||
|
)
|
||||||
|
pipe.load_lora(pipe.dit, "./models/train/FLUX.2-dev-LoRA-splited/epoch-4.safetensors")
|
||||||
|
prompt = "a dog is jumping"
|
||||||
|
image = pipe(prompt, seed=0)
|
||||||
|
image.save("image_FLUX.2-dev_lora.jpg")
|
||||||
Reference in New Issue
Block a user