Merge pull request #646 from modelscope/flux-refactor

Flux refactor
This commit is contained in:
Zhongjie Duan
2025-06-29 18:04:05 +08:00
committed by GitHub
29 changed files with 2706 additions and 4 deletions

View File

@@ -7,6 +7,127 @@ from accelerate import Accelerator
class ImageDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path=None, metadata_path=None,
max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
data_file_keys=("image",),
image_file_extension=("jpg", "jpeg", "png", "webp"),
repeat=1,
args=None,
):
if args is not None:
base_path = args.dataset_base_path
metadata_path = args.dataset_metadata_path
height = args.height
width = args.width
max_pixels = args.max_pixels
data_file_keys = args.data_file_keys.split(",")
repeat = args.dataset_repeat
self.base_path = base_path
self.max_pixels = max_pixels
self.height = height
self.width = width
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.data_file_keys = data_file_keys
self.image_file_extension = image_file_extension
self.repeat = repeat
if height is not None and width is not None:
print("Height and width are fixed. Setting `dynamic_resolution` to False.")
self.dynamic_resolution = False
elif height is None and width is None:
print("Height and width are none. Setting `dynamic_resolution` to True.")
self.dynamic_resolution = True
if metadata_path is None:
print("No metadata. Trying to generate it.")
metadata = self.generate_metadata(base_path)
print(f"{len(metadata)} lines in metadata.")
else:
metadata = pd.read_csv(metadata_path)
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
def generate_metadata(self, folder):
image_list, prompt_list = [], []
file_set = set(os.listdir(folder))
for file_name in file_set:
if "." not in file_name:
continue
file_ext_name = file_name.split(".")[-1].lower()
file_base_name = file_name[:-len(file_ext_name)-1]
if file_ext_name not in self.image_file_extension:
continue
prompt_file_name = file_base_name + ".txt"
if prompt_file_name not in file_set:
continue
with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
prompt = f.read().strip()
image_list.append(file_name)
prompt_list.append(prompt)
metadata = pd.DataFrame()
metadata["image"] = image_list
metadata["prompt"] = prompt_list
return metadata
def crop_and_resize(self, image, target_height, target_width):
width, height = image.size
scale = max(target_width / width, target_height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
return image
def get_height_width(self, image):
if self.dynamic_resolution:
width, height = image.size
if width * height > self.max_pixels:
scale = (width * height / self.max_pixels) ** 0.5
height, width = int(height / scale), int(width / scale)
height = height // self.height_division_factor * self.height_division_factor
width = width // self.width_division_factor * self.width_division_factor
else:
height, width = self.height, self.width
return height, width
def load_image(self, file_path):
image = Image.open(file_path).convert("RGB")
image = self.crop_and_resize(image, *self.get_height_width(image))
return image
def load_data(self, file_path):
return self.load_image(file_path)
def __getitem__(self, data_id):
data = self.data[data_id % len(self.data)].copy()
for key in self.data_file_keys:
if key in data:
path = os.path.join(self.base_path, data[key])
data[key] = self.load_data(path)
if data[key] is None:
warnings.warn(f"cannot load file {data[key]}.")
return None
return data
def __len__(self):
return len(self.data) * self.repeat
class VideoDataset(torch.utils.data.Dataset):
def __init__(
self,
@@ -219,9 +340,10 @@ class DiffusionTrainingModule(torch.nn.Module):
class ModelLogger:
def __init__(self, output_path, remove_prefix_in_ckpt=None):
def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
self.output_path = output_path
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
self.state_dict_converter = state_dict_converter
def on_step_end(self, loss):
@@ -233,6 +355,7 @@ class ModelLogger:
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
accelerator.save(state_dict, path, safe_serialization=True)
@@ -303,3 +426,30 @@ def wan_parser():
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
return parser
def flux_parser():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution..")
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--data_file_keys", type=str, default="image", help="Data file keys in the metadata. Comma-separated.")
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. Only for DiT's LoRA.")
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
return parser