flux-refactor

This commit is contained in:
Artiprocher
2025-06-27 10:20:11 +08:00
parent b603acd36a
commit fcf2fbc07f
15 changed files with 456 additions and 7 deletions

View File

@@ -18,10 +18,13 @@ from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEnc
from ..models.step1x_connector import Qwen2Connector
from ..models.flux_controlnet import FluxControlNet
from ..models.flux_ipadapter import FluxIpAdapter
from ..models.flux_infiniteyou import InfiniteYouImageProjector
from ..models.tiler import FastTileWorker
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora.flux_lora import FluxLoRALoader
from ..vram_management import gradient_checkpoint_forward
@dataclass
@@ -89,6 +92,8 @@ class FluxImagePipeline(BasePipeline):
self.unit_runner = PipelineUnitRunner()
self.qwenvl = None
self.step1x_connector: Qwen2Connector = None
self.infinityou_processor: InfinitYou = None
self.image_proj_model: InfiniteYouImageProjector = None
self.in_iteration_models = ("dit", "step1x_connector", "controlnet")
self.units = [
FluxImageUnit_ShapeChecker(),
@@ -209,7 +214,7 @@ class FluxImagePipeline(BasePipeline):
# ControlNet
controlnet_inputs: list[ControlNetInput] = None,
# IP-Adapter
ipadapter_images: list[Image.Image] = None,
ipadapter_images: Union[list[Image.Image], Image.Image] = None,
ipadapter_scale: float = 1.0,
# EliGen
eligen_entity_prompts: list[str] = None,
@@ -426,6 +431,8 @@ class FluxImageUnit_IPAdapter(PipelineUnit):
ipadapter_images, ipadapter_scale = inputs_shared.get("ipadapter_images", None), inputs_shared.get("ipadapter_scale", 1.0)
if ipadapter_images is None:
return inputs_shared, inputs_posi, inputs_nega
if not isinstance(ipadapter_images, list):
ipadapter_images = [ipadapter_images]
pipe.load_models_to_device(self.onload_model_names)
images = [image.convert("RGB").resize((384, 384), resample=3) for image in ipadapter_images]
@@ -700,6 +707,8 @@ def model_fn_flux_image(
tea_cache: TeaCache = None,
progress_id=0,
num_inference_steps=1,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs
):
if tiled:
@@ -805,13 +814,16 @@ def model_fn_flux_image(
else:
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(
hidden_states, prompt_emb = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None),
)
# ControlNet
if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is not None:
@@ -821,13 +833,16 @@ def model_fn_flux_image(
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
num_joint_blocks = len(dit.blocks)
for block_id, block in enumerate(dit.single_blocks):
hidden_states, prompt_emb = block(
hidden_states, prompt_emb = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
)
# ControlNet
if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None:

View File

@@ -178,7 +178,7 @@ class ModelConfig:
skip_download = dist.get_rank() != 0
# Check whether the origin path is a folder
if self.origin_file_pattern is None:
if self.origin_file_pattern is None or self.origin_file_pattern == "":
self.origin_file_pattern = ""
allow_file_pattern = None
is_folder = True

View File

@@ -7,6 +7,127 @@ from accelerate import Accelerator
class ImageDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path=None, metadata_path=None,
max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
data_file_keys=("image",),
image_file_extension=("jpg", "jpeg", "png", "webp"),
repeat=1,
args=None,
):
if args is not None:
base_path = args.dataset_base_path
metadata_path = args.dataset_metadata_path
height = args.height
width = args.width
max_pixels = args.max_pixels
data_file_keys = args.data_file_keys.split(",")
repeat = args.dataset_repeat
self.base_path = base_path
self.max_pixels = max_pixels
self.height = height
self.width = width
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.data_file_keys = data_file_keys
self.image_file_extension = image_file_extension
self.repeat = repeat
if height is not None and width is not None:
print("Height and width are fixed. Setting `dynamic_resolution` to False.")
self.dynamic_resolution = False
elif height is None and width is None:
print("Height and width are none. Setting `dynamic_resolution` to True.")
self.dynamic_resolution = True
if metadata_path is None:
print("No metadata. Trying to generate it.")
metadata = self.generate_metadata(base_path)
print(f"{len(metadata)} lines in metadata.")
else:
metadata = pd.read_csv(metadata_path)
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
def generate_metadata(self, folder):
image_list, prompt_list = [], []
file_set = set(os.listdir(folder))
for file_name in file_set:
if "." not in file_name:
continue
file_ext_name = file_name.split(".")[-1].lower()
file_base_name = file_name[:-len(file_ext_name)-1]
if file_ext_name not in self.image_file_extension:
continue
prompt_file_name = file_base_name + ".txt"
if prompt_file_name not in file_set:
continue
with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
prompt = f.read().strip()
image_list.append(file_name)
prompt_list.append(prompt)
metadata = pd.DataFrame()
metadata["image"] = image_list
metadata["prompt"] = prompt_list
return metadata
def crop_and_resize(self, image, target_height, target_width):
width, height = image.size
scale = max(target_width / width, target_height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
return image
def get_height_width(self, image):
if self.dynamic_resolution:
width, height = image.size
if width * height > self.max_pixels:
scale = (width * height / self.max_pixels) ** 0.5
height, width = int(height / scale), int(width / scale)
height = height // self.height_division_factor * self.height_division_factor
width = width // self.width_division_factor * self.width_division_factor
else:
height, width = self.height, self.width
return height, width
def load_image(self, file_path):
image = Image.open(file_path).convert("RGB")
image = self.crop_and_resize(image, *self.get_height_width(image))
return image
def load_data(self, file_path):
return self.load_image(file_path)
def __getitem__(self, data_id):
data = self.data[data_id % len(self.data)].copy()
for key in self.data_file_keys:
if key in data:
path = os.path.join(self.base_path, data[key])
data[key] = self.load_data(path)
if data[key] is None:
warnings.warn(f"cannot load file {data[key]}.")
return None
return data
def __len__(self):
return len(self.data) * self.repeat
class VideoDataset(torch.utils.data.Dataset):
def __init__(
self,
@@ -218,9 +339,10 @@ class DiffusionTrainingModule(torch.nn.Module):
class ModelLogger:
def __init__(self, output_path, remove_prefix_in_ckpt=None):
def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
self.output_path = output_path
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
self.state_dict_converter = state_dict_converter
def on_step_end(self, loss):
@@ -232,6 +354,7 @@ class ModelLogger:
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
accelerator.save(state_dict, path, safe_serialization=True)
@@ -302,3 +425,30 @@ def wan_parser():
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
return parser
def flux_parser():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution..")
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--data_file_keys", type=str, default="image", help="Data file keys in the metadata. Comma-separated.")
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.")
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
return parser

View File

@@ -1 +1,2 @@
from .layers import *
from .gradient_checkpointing import *

View File

@@ -0,0 +1,34 @@
import torch
def create_custom_forward(module):
def custom_forward(*inputs, **kwargs):
return module(*inputs, **kwargs)
return custom_forward
def gradient_checkpoint_forward(
model,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
*args,
**kwargs,
):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
model_output = torch.utils.checkpoint.checkpoint(
create_custom_forward(model),
*args,
**kwargs,
use_reentrant=False,
)
elif use_gradient_checkpointing:
model_output = torch.utils.checkpoint.checkpoint(
create_custom_forward(model),
*args,
**kwargs,
use_reentrant=False,
)
else:
model_output = model(*args, **kwargs)
return model_output