Merge pull request #646 from modelscope/flux-refactor

Flux refactor
This commit is contained in:
Zhongjie Duan
2025-06-29 18:04:05 +08:00
committed by GitHub
29 changed files with 2706 additions and 4 deletions

View File

@@ -0,0 +1,13 @@
import torch
from diffsynth.lora import GeneralLoRALoader
from diffsynth.models.lora import FluxLoRAFromCivitai
class FluxLoRALoader(GeneralLoRALoader):
    """LoRA loader for FLUX models.

    Delegates both format detection and weight injection to
    ``FluxLoRAFromCivitai``, which understands the Civitai-style
    FLUX LoRA checkpoint layouts.
    """

    def __init__(self, device="cpu", torch_dtype=torch.float32):
        super().__init__(device=device, torch_dtype=torch_dtype)
        # Backend that knows how to match and apply FLUX LoRA state dicts.
        self.loader = FluxLoRAFromCivitai()

    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
        """Inject ``state_dict_lora`` into ``model``, scaled by ``alpha``."""
        prefix, resource = self.loader.match(model, state_dict_lora)
        self.loader.load(model, state_dict_lora, prefix, alpha=alpha, model_resource=resource)

File diff suppressed because it is too large Load Diff

View File

@@ -168,24 +168,48 @@ class ModelConfig:
def download_if_necessary(self, local_model_path="./models", skip_download=False, use_usp=False):
    """Ensure ``self.path`` points at local model files, downloading if needed.

    Does nothing when ``self.path`` is already set. Otherwise resolves
    ``self.model_id`` / ``self.origin_file_pattern`` against
    ``local_model_path``, downloading missing files via ``snapshot_download``
    (presumably the ModelScope hub client — defined at file top, not visible
    here). With ``use_usp`` only rank 0 downloads; other ranks wait on a
    distributed barrier.

    Args:
        local_model_path: Root directory where models are stored.
        skip_download: Skip the download step and only resolve paths.
        use_usp: Multi-process mode; rank 0 downloads, others wait.
    """
    if self.path is None:
        # Check model_id and origin_file_pattern
        if self.model_id is None:
            raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
        # Skip if not in rank 0
        if use_usp:
            import torch.distributed as dist
            skip_download = dist.get_rank() != 0
        # Check whether the origin path is a folder
        if self.origin_file_pattern is None or self.origin_file_pattern == "":
            self.origin_file_pattern = ""
            allow_file_pattern = None
            is_folder = True
        elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
            # A trailing "/" means "everything under this folder".
            allow_file_pattern = self.origin_file_pattern + "*"
            is_folder = True
        else:
            allow_file_pattern = self.origin_file_pattern
            is_folder = False
        # Download
        if not skip_download:
            # Already-present files are excluded so they are not re-downloaded.
            # NOTE: `root_dir=` requires Python 3.10+.
            downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
            snapshot_download(
                self.model_id,
                local_dir=os.path.join(local_model_path, self.model_id),
                allow_file_pattern=allow_file_pattern,
                ignore_file_pattern=downloaded_files,
                local_files_only=False
            )
        # Let rank 1, 2, ... wait for rank 0
        if use_usp:
            import torch.distributed as dist
            dist.barrier(device_ids=[dist.get_rank()])
        # Return downloaded files
        if is_folder:
            self.path = os.path.join(local_model_path, self.model_id, self.origin_file_pattern)
        else:
            self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
            # Collapse a single-match glob result to a plain string path.
            if isinstance(self.path, list) and len(self.path) == 1:
                self.path = self.path[0]
@@ -614,11 +638,17 @@ class PipelineUnitRunner:
elif unit.seperate_cfg:
# Positive side
processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_posi.update(processor_outputs)
# Negative side
if inputs_shared["cfg_scale"] != 1:
processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_nega.update(processor_outputs)
else:

View File

@@ -7,6 +7,127 @@ from accelerate import Accelerator
class ImageDataset(torch.utils.data.Dataset):
    """Dataset of images paired with text prompts.

    Samples come from a CSV metadata file, or — when no metadata is given —
    from scanning ``base_path`` for image files with sibling ``.txt`` prompt
    files. Images are resized/cropped to a fixed (height, width) or, in
    dynamic-resolution mode, scaled down to at most ``max_pixels`` and snapped
    to the division factors.
    """

    def __init__(
        self,
        base_path=None, metadata_path=None,
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
        data_file_keys=("image",),
        image_file_extension=("jpg", "jpeg", "png", "webp"),
        repeat=1,
        args=None,
    ):
        # CLI args, when provided, override the keyword parameters.
        if args is not None:
            base_path = args.dataset_base_path
            metadata_path = args.dataset_metadata_path
            height = args.height
            width = args.width
            max_pixels = args.max_pixels
            data_file_keys = args.data_file_keys.split(",")
            repeat = args.dataset_repeat
        self.base_path = base_path
        self.max_pixels = max_pixels
        self.height = height
        self.width = width
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.data_file_keys = data_file_keys
        self.image_file_extension = image_file_extension
        self.repeat = repeat
        if height is not None and width is not None:
            print("Height and width are fixed. Setting `dynamic_resolution` to False.")
            self.dynamic_resolution = False
        elif height is None and width is None:
            print("Height and width are none. Setting `dynamic_resolution` to True.")
            self.dynamic_resolution = True
        # NOTE(review): when exactly one of height/width is set,
        # `dynamic_resolution` is never assigned and `get_height_width` will
        # raise AttributeError — confirm whether that case should be rejected.
        if metadata_path is None:
            print("No metadata. Trying to generate it.")
            metadata = self.generate_metadata(base_path)
            print(f"{len(metadata)} lines in metadata.")
        else:
            metadata = pd.read_csv(metadata_path)
        self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]

    def generate_metadata(self, folder):
        """Build a metadata DataFrame by pairing images with ``<name>.txt`` prompts."""
        image_list, prompt_list = [], []
        file_set = set(os.listdir(folder))
        for file_name in file_set:
            if "." not in file_name:
                continue
            file_ext_name = file_name.split(".")[-1].lower()
            file_base_name = file_name[:-len(file_ext_name)-1]
            if file_ext_name not in self.image_file_extension:
                continue
            # Only keep images that have a sibling prompt file.
            prompt_file_name = file_base_name + ".txt"
            if prompt_file_name not in file_set:
                continue
            with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
                prompt = f.read().strip()
            image_list.append(file_name)
            prompt_list.append(prompt)
        metadata = pd.DataFrame()
        metadata["image"] = image_list
        metadata["prompt"] = prompt_list
        return metadata

    def crop_and_resize(self, image, target_height, target_width):
        """Scale the image to cover the target size, then center-crop to it."""
        width, height = image.size
        # `max` ensures both dimensions reach at least the target before cropping.
        scale = max(target_width / width, target_height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
        return image

    def get_height_width(self, image):
        """Return the (height, width) this image should be resized to."""
        if self.dynamic_resolution:
            width, height = image.size
            if width * height > self.max_pixels:
                # Shrink so the pixel count fits within max_pixels.
                scale = (width * height / self.max_pixels) ** 0.5
                height, width = int(height / scale), int(width / scale)
            # Snap down to multiples of the division factors.
            height = height // self.height_division_factor * self.height_division_factor
            width = width // self.width_division_factor * self.width_division_factor
        else:
            height, width = self.height, self.width
        return height, width

    def load_image(self, file_path):
        """Load an image from disk, already cropped/resized to its target size."""
        image = Image.open(file_path).convert("RGB")
        image = self.crop_and_resize(image, *self.get_height_width(image))
        return image

    def load_data(self, file_path):
        # Subclass hook: how a single data file is loaded.
        return self.load_image(file_path)

    def __getitem__(self, data_id):
        """Return one metadata row with its file fields loaded, or None on failure."""
        data = self.data[data_id % len(self.data)].copy()
        for key in self.data_file_keys:
            if key in data:
                path = os.path.join(self.base_path, data[key])
                data[key] = self.load_data(path)
                if data[key] is None:
                    # Fix: warn with the failing path, not the (now-None) value.
                    warnings.warn(f"cannot load file {path}.")
                    return None
        return data

    def __len__(self):
        # Dataset is virtually repeated `repeat` times per epoch.
        return len(self.data) * self.repeat
class VideoDataset(torch.utils.data.Dataset):
def __init__(
self,
@@ -219,9 +340,10 @@ class DiffusionTrainingModule(torch.nn.Module):
class ModelLogger:
def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x: x):
    """Configure where and how checkpoints are saved.

    Args:
        output_path: Directory where checkpoint files are written.
        remove_prefix_in_ckpt: Optional key prefix stripped from the exported
            state dict before saving.
        state_dict_converter: Callable applied to the state dict just before
            saving; defaults to the identity function.
    """
    self.output_path = output_path
    self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
    self.state_dict_converter = state_dict_converter
def on_step_end(self, loss):
@@ -233,6 +355,7 @@ class ModelLogger:
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
accelerator.save(state_dict, path, safe_serialization=True)
@@ -303,3 +426,30 @@ def wan_parser():
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
return parser
def flux_parser():
    """Build the command-line argument parser for FLUX training scripts.

    Returns:
        argparse.ArgumentParser: parser covering dataset, model, LoRA, and
        training-loop options.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    add = parser.add_argument
    # Dataset options.
    add("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
    add("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
    add("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution..")
    add("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
    add("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
    add("--data_file_keys", type=str, default="image", help="Data file keys in the metadata. Comma-separated.")
    add("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
    # Model loading options.
    add("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
    add("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
    # Training-loop options.
    add("--learning_rate", type=float, default=1e-4, help="Learning rate.")
    add("--num_epochs", type=int, default=1, help="Number of epochs.")
    add("--output_path", type=str, default="./models", help="Output save path.")
    add("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
    add("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
    # LoRA options.
    add("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
    add("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
    add("--lora_rank", type=int, default=32, help="Rank of LoRA.")
    add("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
    add("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.")
    # Memory / gradient options.
    add("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
    add("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
    add("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
    return parser

View File

@@ -1 +1,2 @@
from .layers import *
from .gradient_checkpointing import *

View File

@@ -0,0 +1,34 @@
import torch
def create_custom_forward(module):
    """Wrap *module* in a plain function so checkpoint utilities can call it."""
    def _invoke(*args, **kwargs):
        return module(*args, **kwargs)
    return _invoke
def gradient_checkpoint_forward(
    model,
    use_gradient_checkpointing,
    use_gradient_checkpointing_offload,
    *args,
    **kwargs,
):
    """Run ``model(*args, **kwargs)``, optionally under gradient checkpointing.

    Args:
        model: Callable module to execute.
        use_gradient_checkpointing: Recompute activations in backward instead
            of storing them.
        use_gradient_checkpointing_offload: Additionally move saved tensors to
            CPU memory; implies checkpointing (takes precedence, matching the
            original branch order).

    Returns:
        Whatever ``model`` returns.
    """
    # Fast path: no checkpointing requested — call the model directly.
    if not (use_gradient_checkpointing or use_gradient_checkpointing_offload):
        return model(*args, **kwargs)
    from contextlib import nullcontext
    # Deduplicated: one checkpoint call, with CPU offload as an optional context.
    context = torch.autograd.graph.save_on_cpu() if use_gradient_checkpointing_offload else nullcontext()
    with context:
        return torch.utils.checkpoint.checkpoint(
            create_custom_forward(model),
            *args,
            **kwargs,
            use_reentrant=False,
        )