From fcf2fbc07fd235c3b7b9465108aca14165a260d4 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 27 Jun 2025 10:20:11 +0800 Subject: [PATCH] flux-refactor --- diffsynth/pipelines/flux_image_new.py | 25 ++- diffsynth/pipelines/wan_video_new.py | 2 +- diffsynth/trainers/utils.py | 152 +++++++++++++++++- diffsynth/vram_management/__init__.py | 1 + .../vram_management/gradient_checkpointing.py | 34 ++++ download.py | 3 + .../{Flex.2-preview.py => FLEX.2-preview.py} | 0 .../full/FLUX.1-dev-IP-Adapter.sh | 14 ++ .../flux/model_training/full/FLUX.1-dev.sh | 12 ++ .../full/accelerate_config.yaml | 22 +++ .../flux/model_training/lora/FLUX.1-dev.sh | 15 ++ examples/flux/model_training/train.py | 117 ++++++++++++++ .../validate_full/FLUX.1-dev-IP-Adapter.py | 28 ++++ .../validate_full/FLUX.1-dev.py | 20 +++ .../validate_lora/FLUX.1-dev.py | 18 +++ 15 files changed, 456 insertions(+), 7 deletions(-) create mode 100644 diffsynth/vram_management/gradient_checkpointing.py create mode 100644 download.py rename examples/flux/model_inference/{Flex.2-preview.py => FLEX.2-preview.py} (100%) create mode 100644 examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh create mode 100644 examples/flux/model_training/full/FLUX.1-dev.sh create mode 100644 examples/flux/model_training/full/accelerate_config.yaml create mode 100644 examples/flux/model_training/lora/FLUX.1-dev.sh create mode 100644 examples/flux/model_training/train.py create mode 100644 examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py create mode 100644 examples/flux/model_training/validate_full/FLUX.1-dev.py create mode 100644 examples/flux/model_training/validate_lora/FLUX.1-dev.py diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py index 5dd553a..3bf971d 100644 --- a/diffsynth/pipelines/flux_image_new.py +++ b/diffsynth/pipelines/flux_image_new.py @@ -18,10 +18,13 @@ from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEnc from ..models.step1x_connector import Qwen2Connector from ..models.flux_controlnet import FluxControlNet from ..models.flux_ipadapter import FluxIpAdapter +from ..models.flux_infiniteyou import InfiniteYouImageProjector from ..models.tiler import FastTileWorker from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit from ..lora.flux_lora import FluxLoRALoader +from ..vram_management import gradient_checkpoint_forward + @dataclass @@ -89,6 +92,8 @@ class FluxImagePipeline(BasePipeline): self.unit_runner = PipelineUnitRunner() self.qwenvl = None self.step1x_connector: Qwen2Connector = None + self.infinityou_processor: InfinitYou = None + self.image_proj_model: InfiniteYouImageProjector = None self.in_iteration_models = ("dit", "step1x_connector", "controlnet") self.units = [ FluxImageUnit_ShapeChecker(), @@ -209,7 +214,7 @@ class FluxImagePipeline(BasePipeline): # ControlNet controlnet_inputs: list[ControlNetInput] = None, # IP-Adapter - ipadapter_images: list[Image.Image] = None, + ipadapter_images: Union[list[Image.Image], Image.Image] = None, ipadapter_scale: float = 1.0, # EliGen eligen_entity_prompts: list[str] = None, @@ -426,6 +431,8 @@ class FluxImageUnit_IPAdapter(PipelineUnit): ipadapter_images, ipadapter_scale = inputs_shared.get("ipadapter_images", None), inputs_shared.get("ipadapter_scale", 1.0) if ipadapter_images is None: return inputs_shared, inputs_posi, inputs_nega + if not isinstance(ipadapter_images, list): + ipadapter_images = [ipadapter_images] 
pipe.load_models_to_device(self.onload_model_names) images = [image.convert("RGB").resize((384, 384), resample=3) for image in ipadapter_images] @@ -700,6 +707,8 @@ def model_fn_flux_image( tea_cache: TeaCache = None, progress_id=0, num_inference_steps=1, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, **kwargs ): if tiled: @@ -805,13 +814,16 @@ def model_fn_flux_image( else: # Joint Blocks for block_id, block in enumerate(dit.blocks): - hidden_states, prompt_emb = block( + hidden_states, prompt_emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, - ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None) + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None), ) # ControlNet if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is not None: @@ -821,13 +833,16 @@ def model_fn_flux_image( hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) num_joint_blocks = len(dit.blocks) for block_id, block in enumerate(dit.single_blocks): - hidden_states, prompt_emb = block( + hidden_states, prompt_emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, - ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None) + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None), ) # ControlNet if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None: diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index a9d9ad1..14e564c 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -178,7 +178,7 @@ class ModelConfig: skip_download = dist.get_rank() != 0 # Check whether the origin path is a folder - if self.origin_file_pattern is None: + if self.origin_file_pattern is None or self.origin_file_pattern == "": self.origin_file_pattern = "" allow_file_pattern = None is_folder = True diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index d627dab..7af03e6 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -7,6 +7,127 @@ from accelerate import Accelerator +class ImageDataset(torch.utils.data.Dataset): + def __init__( + self, + base_path=None, metadata_path=None, + max_pixels=1920*1080, height=None, width=None, + height_division_factor=16, width_division_factor=16, + data_file_keys=("image",), + image_file_extension=("jpg", "jpeg", "png", "webp"), + repeat=1, + args=None, + ): + if args is not None: + base_path = args.dataset_base_path + metadata_path = args.dataset_metadata_path + height = args.height + width = args.width + max_pixels = args.max_pixels + data_file_keys = args.data_file_keys.split(",") + repeat = args.dataset_repeat + + self.base_path = base_path + self.max_pixels = max_pixels + self.height = height + self.width = width + self.height_division_factor = height_division_factor + self.width_division_factor = width_division_factor + self.data_file_keys = data_file_keys + self.image_file_extension = image_file_extension + self.repeat = repeat + + if height is not None and width is not None: + print("Height and width are fixed. 
Setting `dynamic_resolution` to False.") + self.dynamic_resolution = False + else: + print("Height or width is None. Setting `dynamic_resolution` to True.") + self.dynamic_resolution = True + + if metadata_path is None: + print("No metadata. Trying to generate it.") + metadata = self.generate_metadata(base_path) + print(f"{len(metadata)} lines in metadata.") + else: + metadata = pd.read_csv(metadata_path) + self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] + + + def generate_metadata(self, folder): + image_list, prompt_list = [], [] + file_set = set(os.listdir(folder)) + for file_name in file_set: + if "." not in file_name: + continue + file_ext_name = file_name.split(".")[-1].lower() + file_base_name = file_name[:-len(file_ext_name)-1] + if file_ext_name not in self.image_file_extension: + continue + prompt_file_name = file_base_name + ".txt" + if prompt_file_name not in file_set: + continue + with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f: + prompt = f.read().strip() + image_list.append(file_name) + prompt_list.append(prompt) + metadata = pd.DataFrame() + metadata["image"] = image_list + metadata["prompt"] = prompt_list + return metadata + + + def crop_and_resize(self, image, target_height, target_width): + width, height = image.size + scale = max(target_width / width, target_height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + image = torchvision.transforms.functional.center_crop(image, (target_height, target_width)) + return image + + + def get_height_width(self, image): + if self.dynamic_resolution: + width, height = image.size + if width * height > self.max_pixels: + scale = (width * height / self.max_pixels) ** 0.5 + height, width = int(height / scale), int(width / scale) + height = height // self.height_division_factor * self.height_division_factor + width = width // self.width_division_factor * self.width_division_factor + else: + height, width = self.height, self.width + return height, width + + + def load_image(self, file_path): + image = Image.open(file_path).convert("RGB") + image = self.crop_and_resize(image, *self.get_height_width(image)) + return image + + + def load_data(self, file_path): + return self.load_image(file_path) + + + def __getitem__(self, data_id): + data = self.data[data_id % len(self.data)].copy() + for key in self.data_file_keys: + if key in data: + path = os.path.join(self.base_path, data[key]) + data[key] = self.load_data(path) + if data[key] is None: + warnings.warn(f"Cannot load file {path}.") + return None + return data + + + def __len__(self): + return len(self.data) * self.repeat + + + class VideoDataset(torch.utils.data.Dataset): def __init__( self, @@ -218,9 +339,10 @@ class DiffusionTrainingModule(torch.nn.Module): class ModelLogger: - def __init__(self, output_path, remove_prefix_in_ckpt=None): + def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x): self.output_path = output_path self.remove_prefix_in_ckpt = remove_prefix_in_ckpt + self.state_dict_converter = state_dict_converter def on_step_end(self, loss): @@ -232,6 +354,7 @@ class ModelLogger: if accelerator.is_main_process: state_dict = accelerator.get_state_dict(model) state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) + state_dict = 
self.state_dict_converter(state_dict) + os.makedirs(self.output_path, exist_ok=True) + path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors") + accelerator.save(state_dict, path, safe_serialization=True) @@ -302,3 +425,30 @@ def wan_parser(): parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") return parser + + +def flux_parser(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.") + parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.") + parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per image, used for dynamic resolution.") + parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--data_file_keys", type=str, default="image", help="Data file keys in the metadata. Comma-separated.") + parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.") + parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.") + parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Comma-separated.") + parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.") + parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.") + parser.add_argument("--output_path", type=str, default="./models", help="Output save path.") + parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Prefix to remove from parameter names in the saved checkpoint.") + parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.") + parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.") + parser.add_argument("--lora_target_modules", type=str, default="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", help="Which layers LoRA is added to.") + parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.") + parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.") + parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the LoRA format to the open-source format. 
Only for DiT's LoRA.") + parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") + parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") + parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") + return parser diff --git a/diffsynth/vram_management/__init__.py b/diffsynth/vram_management/__init__.py index 69a388d..5b07580 100644 --- a/diffsynth/vram_management/__init__.py +++ b/diffsynth/vram_management/__init__.py @@ -1 +1,2 @@ from .layers import * +from .gradient_checkpointing import * diff --git a/diffsynth/vram_management/gradient_checkpointing.py b/diffsynth/vram_management/gradient_checkpointing.py new file mode 100644 index 0000000..b356415 --- /dev/null +++ b/diffsynth/vram_management/gradient_checkpointing.py @@ -0,0 +1,34 @@ +import torch + + +def create_custom_forward(module): + def custom_forward(*inputs, **kwargs): + return module(*inputs, **kwargs) + return custom_forward + + +def gradient_checkpoint_forward( + model, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + *args, + **kwargs, +): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + model_output = torch.utils.checkpoint.checkpoint( + create_custom_forward(model), + *args, + **kwargs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + model_output = torch.utils.checkpoint.checkpoint( + create_custom_forward(model), + *args, + **kwargs, + use_reentrant=False, + ) + else: + model_output = model(*args, **kwargs) + return model_output diff --git a/download.py b/download.py new file mode 100644 index 0000000..1404010 --- /dev/null +++ b/download.py @@ -0,0 +1,3 @@ +# Download the model +from modelscope import snapshot_download +model_dir = snapshot_download('black-forest-labs/FLUX.1-Kontext-dev', cache_dir="models", ignore_file_pattern="transformer/*") \ No newline at end of file diff --git a/examples/flux/model_inference/Flex.2-preview.py b/examples/flux/model_inference/FLEX.2-preview.py similarity index 100% rename from examples/flux/model_inference/Flex.2-preview.py rename to examples/flux/model_inference/FLEX.2-preview.py diff --git a/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh b/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh new file mode 100644 index 0000000..43bc006 --- /dev/null +++ b/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh @@ -0,0 +1,14 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_ipadapter.csv \ + --data_file_keys "image,ipadapter_images" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.ipadapter." 
\ + --output_path "./models/train/FLUX.1-dev-IP-Adapter_full" \ + --trainable_models "ipadapter" \ + --extra_inputs "ipadapter_images" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/full/FLUX.1-dev.sh b/examples/flux/model_training/full/FLUX.1-dev.sh new file mode 100644 index 0000000..9254957 --- /dev/null +++ b/examples/flux/model_training/full/FLUX.1-dev.sh @@ -0,0 +1,12 @@ +accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-dev_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/full/accelerate_config.yaml b/examples/flux/model_training/full/accelerate_config.yaml new file mode 100644 index 0000000..83280f7 --- /dev/null +++ b/examples/flux/model_training/full/accelerate_config.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/flux/model_training/lora/FLUX.1-dev.sh b/examples/flux/model_training/lora/FLUX.1-dev.sh new file mode 100644 index 0000000..4b207ef --- /dev/null +++ b/examples/flux/model_training/lora/FLUX.1-dev.sh @@ -0,0 +1,15 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/FLUX.1-dev_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --align_to_opensource_format \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py new file mode 100644 index 0000000..6717c9e --- /dev/null +++ b/examples/flux/model_training/train.py @@ -0,0 +1,117 @@ +import torch, os, json +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig +from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser +from diffsynth.models.lora import FluxLoRAConverter +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + + +class FluxTrainingModule(DiffusionTrainingModule): + def __init__( + self, + model_paths=None, model_id_with_origin_paths=None, + trainable_models=None, + lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + extra_inputs=None, + ): + super().__init__() + # Load models + model_configs = [] + if model_paths is not None: + model_paths = json.loads(model_paths) + model_configs += [ModelConfig(path=path) for path in model_paths] + if model_id_with_origin_paths is not None: + model_id_with_origin_paths = model_id_with_origin_paths.split(",") + model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths] + self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs) + + # Reset training scheduler + self.pipe.scheduler.set_timesteps(1000, training=True) + + # Freeze untrainable models + self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(",")) + + # Add LoRA to the base models + if lora_base_model is not None: + model = self.add_lora_to_model( + getattr(self.pipe, lora_base_model), + target_modules=lora_target_modules.split(","), + lora_rank=lora_rank + ) + setattr(self.pipe, lora_base_model, model) + + # Store other configs + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + + + def forward_preprocess(self, data): + # CFG-sensitive parameters + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {} + + # CFG-insensitive parameters + inputs_shared = { + # Fill in the input parameters as you would + # when running this pipeline for inference. + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + # Do not modify the following parameters + # unless you clearly understand the consequences. + "cfg_scale": 1, + "embedded_guidance": 1, + "t5_sequence_length": 512, + "tiled": False, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": self.use_gradient_checkpointing, + "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, + } + + # Extra inputs + for extra_input in self.extra_inputs: + inputs_shared[extra_input] = data[extra_input] + + # Pipeline units will automatically process the input parameters. 
+ for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + return {**inputs_shared, **inputs_posi} + + + def forward(self, data, inputs=None): + if inputs is None: inputs = self.forward_preprocess(data) + models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} + loss = self.pipe.training_loss(**models, **inputs) + return loss + + + +if __name__ == "__main__": + parser = flux_parser() + args = parser.parse_args() + dataset = ImageDataset(args=args) + model = FluxTrainingModule( + model_paths=args.model_paths, + model_id_with_origin_paths=args.model_id_with_origin_paths, + trainable_models=args.trainable_models, + lora_base_model=args.lora_base_model, + lora_target_modules=args.lora_target_modules, + lora_rank=args.lora_rank, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + extra_inputs=args.extra_inputs, + ) + model_logger = ModelLogger( + args.output_path, + remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, + state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x, + ) + optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate) + scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) + launch_training_task( + dataset, model, model_logger, optimizer, scheduler, + num_epochs=args.num_epochs, + gradient_accumulation_steps=args.gradient_accumulation_steps, + ) diff --git a/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py b/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py new file mode 100644 index 0000000..b6bab3d --- /dev/null +++ b/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py @@ -0,0 +1,28 @@ +import torch +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig +from diffsynth import load_state_dict +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"), + ModelConfig(model_id="google/siglip-so400m-patch14-384"), + ], +) +state_dict = load_state_dict("models/train/FLUX.1-dev-IP-Adapter_full/epoch-0.safetensors") +pipe.ipadapter.load_state_dict(state_dict) + +image = pipe( + prompt="a dog", + ipadapter_images=Image.open("data/example_image_dataset/1.jpg"), + height=768, width=768, + seed=0 +) +image.save("image_FLUX.1-dev-IP-Adapter_full.jpg") diff --git a/examples/flux/model_training/validate_full/FLUX.1-dev.py b/examples/flux/model_training/validate_full/FLUX.1-dev.py new file mode 100644 index 0000000..d3adf7a --- /dev/null +++ b/examples/flux/model_training/validate_full/FLUX.1-dev.py @@ -0,0 +1,20 @@ +import torch +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig +from diffsynth import load_state_dict + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", 
origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +state_dict = load_state_dict("models/train/FLUX.1-dev_full/epoch-0.safetensors") +pipe.dit.load_state_dict(state_dict) + +image = pipe(prompt="a dog", seed=0) +image.save("image_FLUX.1-dev_full.jpg") diff --git a/examples/flux/model_training/validate_lora/FLUX.1-dev.py b/examples/flux/model_training/validate_lora/FLUX.1-dev.py new file mode 100644 index 0000000..d1aebef --- /dev/null +++ b/examples/flux/model_training/validate_lora/FLUX.1-dev.py @@ -0,0 +1,18 @@ +import torch +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev_lora/epoch-4.safetensors", alpha=1) + +image = pipe(prompt="a dog", seed=0) +image.save("image_FLUX.1-dev_lora.jpg")
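
Note (not part of the patch): a minimal usage sketch of the new gradient_checkpoint_forward helper added in diffsynth/vram_management/gradient_checkpointing.py. The two positional flags mirror the --use_gradient_checkpointing and --use_gradient_checkpointing_offload training arguments; the Linear block below is a stand-in module used only for illustration.

    import torch
    from diffsynth.vram_management import gradient_checkpoint_forward

    block = torch.nn.Linear(16, 16)
    x = torch.randn(2, 16, requires_grad=True)

    # With both flags False this is a plain block(x) call. With
    # use_gradient_checkpointing=True, activations are recomputed during the
    # backward pass instead of being stored. The offload flag additionally
    # keeps checkpointed tensors in CPU memory via torch.autograd.graph.save_on_cpu().
    y = gradient_checkpoint_forward(
        block,
        True,   # use_gradient_checkpointing
        False,  # use_gradient_checkpointing_offload
        x,
    )
    y.sum().backward()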
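
Note on dataset layout (not part of the patch): when --dataset_metadata_path is omitted, ImageDataset.generate_metadata pairs each image in the base folder with a .txt prompt file sharing the same base name. A minimal sketch, with a hypothetical prompt string:

    import os
    from PIL import Image

    # 1.jpg and 1.txt share a base name, so generate_metadata() pairs them into
    # one row of the metadata table, with columns "image" and "prompt".
    os.makedirs("data/example_image_dataset", exist_ok=True)
    Image.new("RGB", (1024, 1024), "white").save("data/example_image_dataset/1.jpg")
    with open("data/example_image_dataset/1.txt", "w", encoding="utf-8") as f:
        f.write("a photo of a dog")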