From d21676b4dcb3126b5511683e22c39895f2d98463 Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Fri, 11 Apr 2025 15:31:30 +0800
Subject: [PATCH] support reference image

---
 diffsynth/data/image_pulse.py               | 125 ++++++++++++
 diffsynth/models/flux_reference_embedder.py |  20 ++
 diffsynth/pipelines/flux_image.py           | 100 ++++++++--
 train_flux_reference.py                     | 206 ++++++++++++++++++++
 4 files changed, 431 insertions(+), 20 deletions(-)
 create mode 100644 diffsynth/data/image_pulse.py
 create mode 100644 diffsynth/models/flux_reference_embedder.py
 create mode 100644 train_flux_reference.py

diff --git a/diffsynth/data/image_pulse.py b/diffsynth/data/image_pulse.py
new file mode 100644
index 0000000..0ef9236
--- /dev/null
+++ b/diffsynth/data/image_pulse.py
@@ -0,0 +1,125 @@
+import torch, os, json, torchvision
+from PIL import Image
+from torchvision.transforms import v2
+
+
+
+class SingleTaskDataset(torch.utils.data.Dataset):
+    def __init__(self, base_path, keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")), height=1024, width=1024, random=True, steps_per_epoch=1000, metadata_path=None):
+        self.base_path = base_path
+        self.keys = keys
+        self.metadata = []
+        self.bad_data = []
+        self.height = height
+        self.width = width
+        self.random = random
+        self.steps_per_epoch = steps_per_epoch
+        self.image_process = v2.Compose([
+            v2.CenterCrop(size=(height, width)),
+            v2.ToImage(),
+            v2.ToDtype(torch.float32, scale=True),
+            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+        ])
+        if metadata_path is None:
+            self.search_for_data("", report_data_log=True)
+            self.report_data_log()
+        else:
+            with open(metadata_path, "r", encoding="utf-8-sig") as f:
+                self.metadata = json.load(f)
+
+
+    def report_data_log(self):
+        print(f"{len(self.metadata)} valid data, {len(self.bad_data)} invalid data.")
+
+
+    def dump_metadata(self, path):
+        with open(path, "w", encoding="utf-8") as f:
+            json.dump(self.metadata, f, ensure_ascii=False, indent=4)
+
+
+    def parse_json_file(self, absolute_path, relative_path):
+        data_list = []
+        with open(absolute_path, "r") as f:
+            metadata = json.load(f)
+        for image_1, image_2, instruction in self.keys:
+            image_1 = os.path.join(relative_path, metadata[image_1])
+            image_2 = os.path.join(relative_path, metadata[image_2])
+            instruction = metadata[instruction]
+            data_list.append((image_1, image_2, instruction))
+        return data_list
+
+
+    def search_for_data(self, path, report_data_log=False):
+        now_path = os.path.join(self.base_path, path)
+        if os.path.isfile(now_path) and path.endswith(".json"):
+            try:
+                data_list = self.parse_json_file(now_path, os.path.dirname(path))
+                self.metadata.extend(data_list)
+            except Exception:
+                self.bad_data.append(now_path)
+        elif os.path.isdir(now_path):
+            for sub_path in os.listdir(now_path):
+                self.search_for_data(os.path.join(path, sub_path))
+                if report_data_log and os.path.isdir(os.path.join(self.base_path, path, sub_path)):
+                    self.report_data_log()
+
+
+    def load_image(self, image_path):
+        image_path = os.path.join(self.base_path, image_path)
+        image = Image.open(image_path).convert("RGB")
+        width, height = image.size
+        scale = max(self.width / width, self.height / height)
+        image = torchvision.transforms.functional.resize(
+            image,
+            (round(height*scale), round(width*scale)),
+            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
+        )
+        image = self.image_process(image)
+        return image
+
+
+    def load_data(self, data_id):
+        image_1, image_2, instruction = self.metadata[data_id]
+        image_1 = self.load_image(image_1)
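SingleTaskDataset walks base_path recursively on first use, collecting an (image_1, image_2, instruction) triple for each key group in every JSON file it can parse, which can be slow on large corpora; dump_metadata caches the scan so later runs can reload it through metadata_path. A minimal usage sketch (the paths are placeholders):

    from diffsynth.data.image_pulse import SingleTaskDataset

    # First run: scan the directory tree and cache the collected metadata.
    dataset = SingleTaskDataset("/data/my_editing_dataset")
    dataset.dump_metadata("/data/my_editing_dataset/metadata.json")

    # Later runs: skip the scan and load the cached metadata directly.
    dataset = SingleTaskDataset(
        "/data/my_editing_dataset",
        metadata_path="/data/my_editing_dataset/metadata.json",
    )
    sample = dataset[0]  # with random=True this returns a random valid sample
    print(sample["instruction"], sample["image_1"].shape)  # torch.Size([3, 1024, 1024])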
+        image_2 = self.load_image(image_2)
+        return {"image_1": image_1, "image_2": image_2, "instruction": instruction}
+
+
+    def __getitem__(self, data_id):
+        if self.random:
+            while True:
+                try:
+                    data_id = (torch.randint(0, len(self.metadata), (1,))[0] + data_id) % len(self.metadata)
+                    data = self.load_data(data_id)
+                    return data
+                except Exception:
+                    continue
+        else:
+            return self.load_data(data_id)
+
+
+    def __len__(self):
+        return self.steps_per_epoch if self.random else len(self.metadata)
+
+
+
+class MultiTaskDataset(torch.utils.data.Dataset):
+    def __init__(self, dataset_list, dataset_weight, steps_per_epoch=1000):
+        self.dataset_list = dataset_list
+        self.dataset_weight = torch.tensor(dataset_weight, dtype=torch.float)
+        self.steps_per_epoch = steps_per_epoch
+
+
+    def __getitem__(self, data_id):
+        while True:
+            try:
+                dataset_id = torch.multinomial(self.dataset_weight, 1).tolist()[0]
+                data_id = torch.randint(0, len(self.dataset_list[dataset_id]), (1,))[0]
+                data = self.dataset_list[dataset_id][data_id]
+                return data
+            except Exception:
+                continue
+
+
+    def __len__(self):
+        return self.steps_per_epoch
diff --git a/diffsynth/models/flux_reference_embedder.py b/diffsynth/models/flux_reference_embedder.py
new file mode 100644
index 0000000..994ffaa
--- /dev/null
+++ b/diffsynth/models/flux_reference_embedder.py
@@ -0,0 +1,20 @@
+from .sd3_dit import TimestepEmbeddings
+from .flux_dit import RoPEEmbedding
+import torch
+from einops import repeat
+
+
+class FluxReferenceEmbedder(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
+        self.idx_embedder = TimestepEmbeddings(256, 256)
+
+    def forward(self, image_ids, idx, dtype):
+        pos_emb = self.pos_embedder(image_ids)
+        idx_emb = self.idx_embedder(idx, dtype=dtype)
+        length = pos_emb.shape[2]
+        pos_emb = repeat(pos_emb, "B N L C H W -> 1 N (B L) C H W")
+        idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length)
+        image_rotary_emb = pos_emb + idx_emb
+        return image_rotary_emb
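The embedder flattens the per-image batch axis of the reference RoPE table into the sequence axis and adds a per-image index embedding (TimestepEmbeddings driven by idx) to every rotary entry, so tokens from different reference images remain distinguishable. The bookkeeping is easiest to see on dummy tensors; the (B, 1, L, 64, 2, 2) rotary layout below is inferred from the repeat patterns above, not taken from RoPEEmbedding itself:

    import torch
    from einops import repeat

    B, L = 2, 1024                             # two reference images, 1024 patch tokens each
    pos_emb = torch.randn(B, 1, L, 64, 2, 2)   # per-image RoPE table
    idx_emb = torch.randn(B, 256)              # per-image index embedding, 256 = 64 * 2 * 2

    pos_emb = repeat(pos_emb, "B N L C H W -> 1 N (B L) C H W")
    idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=L)
    print((pos_emb + idx_emb).shape)           # torch.Size([1, 1, 2048, 64, 2, 2])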
diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py
index c0729fc..bf41c7d 100644
--- a/diffsynth/pipelines/flux_image.py
+++ b/diffsynth/pipelines/flux_image.py
@@ -1,10 +1,12 @@
 from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
 from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
+from ..models.flux_reference_embedder import FluxReferenceEmbedder
 from ..prompters import FluxPrompter
 from ..schedulers import FlowMatchScheduler
 from .base import BasePipeline
 from typing import List
 import torch
+from einops import rearrange
 from tqdm import tqdm
 import numpy as np
 from PIL import Image
@@ -32,6 +34,7 @@ class FluxImagePipeline(BasePipeline):
         self.ipadapter: FluxIpAdapter = None
         self.ipadapter_image_encoder: SiglipVisionModel = None
         self.infinityou_processor: InfinitYou = None
+        self.reference_embedder: FluxReferenceEmbedder = None
         self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
@@ -360,6 +363,20 @@ class FluxImagePipeline(BasePipeline):
             return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
         else:
             return {}, controlnet_image
+
+
+    def prepare_reference_images(self, reference_images, tiled=False, tile_size=64, tile_stride=32):
+        if reference_images is not None:
+            hidden_states_ref = []
+            self.load_models_to_device(['vae_encoder'])
+            for reference_image in reference_images:
+                reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
+                latents = self.encode_image(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+                hidden_states_ref.append(latents)
+            hidden_states_ref = torch.concat(hidden_states_ref, dim=0)
+            return {"hidden_states_ref": hidden_states_ref}
+        else:
+            return {"hidden_states_ref": None}


     @torch.no_grad()
@@ -398,6 +415,8 @@ class FluxImagePipeline(BasePipeline):
         # InfiniteYou
         infinityou_id_image=None,
         infinityou_guidance=1.0,
+        # Reference images
+        reference_images=None,
         # TeaCache
         tea_cache_l1_thresh=None,
         # Tile
@@ -436,6 +455,9 @@ class FluxImagePipeline(BasePipeline):
         # ControlNets
         controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
+
+        # Reference images
+        reference_kwargs = self.prepare_reference_images(reference_images, **tiler_kwargs)

         # TeaCache
         tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
@@ -447,9 +469,9 @@ class FluxImagePipeline(BasePipeline):
         # Positive side
         inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
-            dit=self.dit, controlnet=self.controlnet,
+            dit=self.dit, controlnet=self.controlnet, reference_embedder=self.reference_embedder,
             hidden_states=latents, timestep=timestep,
-            **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
+            **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **reference_kwargs,
         )
         noise_pred_posi = self.control_noise_via_local_prompts(
             prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -464,9 +486,9 @@ class FluxImagePipeline(BasePipeline):
         if cfg_scale != 1.0:
             # Negative side
             noise_pred_nega = lets_dance_flux(
-                dit=self.dit, controlnet=self.controlnet,
+                dit=self.dit, controlnet=self.controlnet, reference_embedder=self.reference_embedder,
                 hidden_states=latents, timestep=timestep,
-                **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
+                **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **reference_kwargs,
             )
             noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
         else:
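With these changes, __call__ accepts a reference_images list whose members are VAE-encoded and forwarded to the DiT as extra tokens. Note the pipeline only initializes reference_embedder to None (and model_names does not load it), so it must be attached, with trained weights, before inference. A hypothetical sketch; the model paths follow the examples given in the training script's help strings:

    import torch
    from PIL import Image
    from diffsynth import ModelManager, FluxImagePipeline
    from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
    model_manager.load_models([
        "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
        "models/FLUX/FLUX.1-dev/text_encoder_2",
        "models/FLUX/FLUX.1-dev/ae.safetensors",
        "models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
    ])
    pipe = FluxImagePipeline.from_model_manager(model_manager)
    pipe.reference_embedder = FluxReferenceEmbedder().to(device="cuda", dtype=torch.bfloat16)
    # load trained reference-embedder (and LoRA) weights here

    image = pipe(
        prompt="make the sky in the reference image a sunset",
        reference_images=[Image.open("reference.jpg").convert("RGB")],
        height=1024, width=1024, seed=0,
    )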
@@ -586,6 +608,7 @@ class TeaCache:
 def lets_dance_flux(
     dit: FluxDiT,
     controlnet: FluxMultiControlNetManager = None,
+    reference_embedder: FluxReferenceEmbedder = None,
     hidden_states=None,
     timestep=None,
     prompt_emb=None,
@@ -594,6 +617,7 @@ def lets_dance_flux(
     text_ids=None,
     image_ids=None,
     controlnet_frames=None,
+    hidden_states_ref=None,
     tiled=False,
     tile_size=128,
     tile_stride=64,
@@ -603,6 +627,7 @@ def lets_dance_flux(
     id_emb=None,
     infinityou_guidance=None,
     tea_cache: TeaCache = None,
+    use_gradient_checkpointing=False,
     **kwargs
 ):
     if tiled:
@@ -671,26 +696,52 @@ def lets_dance_flux(
     prompt_emb = dit.context_embedder(prompt_emb)
     image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
     attention_mask = None
+
+    # Reference images
+    if hidden_states_ref is not None:
+        # RoPE
+        image_ids_ref = dit.prepare_image_ids(hidden_states_ref)
+        idx = torch.arange(0, image_ids_ref.shape[0]).to(dtype=hidden_states.dtype, device=hidden_states.device) * 100
+        image_rotary_emb_ref = reference_embedder(image_ids_ref, idx, dtype=hidden_states.dtype)
+        image_rotary_emb = torch.cat((image_rotary_emb, image_rotary_emb_ref), dim=2)
+        # hidden_states
+        original_hidden_states_length = hidden_states.shape[1]
+        hidden_states_ref = dit.patchify(hidden_states_ref)
+        hidden_states_ref = dit.x_embedder(hidden_states_ref)
+        hidden_states_ref = rearrange(hidden_states_ref, "B L C -> 1 (B L) C")
+        hidden_states = torch.cat((hidden_states, hidden_states_ref), dim=1)

     # TeaCache
     if tea_cache is not None:
         tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
     else:
         tea_cache_update = False
+
+    def create_custom_forward(module):
+        def custom_forward(*inputs):
+            return module(*inputs)
+        return custom_forward

     if tea_cache_update:
         hidden_states = tea_cache.update(hidden_states)
     else:
         # Joint Blocks
         for block_id, block in enumerate(dit.blocks):
-            hidden_states, prompt_emb = block(
-                hidden_states,
-                prompt_emb,
-                conditioning,
-                image_rotary_emb,
-                attention_mask,
-                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
-            )
+            if use_gradient_checkpointing:
+                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None),
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, prompt_emb = block(
+                    hidden_states,
+                    prompt_emb,
+                    conditioning,
+                    image_rotary_emb,
+                    attention_mask,
+                    ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
+                )
             # ControlNet
             if controlnet is not None and controlnet_frames is not None:
                 hidden_states = hidden_states + controlnet_res_stack[block_id]
@@ -699,14 +750,21 @@ def lets_dance_flux(
         hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
         num_joint_blocks = len(dit.blocks)
         for block_id, block in enumerate(dit.single_blocks):
-            hidden_states, prompt_emb = block(
-                hidden_states,
-                prompt_emb,
-                conditioning,
-                image_rotary_emb,
-                attention_mask,
-                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
-            )
+            if use_gradient_checkpointing:
+                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, prompt_emb = block(
+                    hidden_states,
+                    prompt_emb,
+                    conditioning,
+                    image_rotary_emb,
+                    attention_mask,
+                    ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
+                )
             # ControlNet
             if controlnet is not None and controlnet_frames is not None:
                 hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
@@ -715,6 +773,8 @@ def lets_dance_flux(
     if tea_cache is not None:
         tea_cache.store(hidden_states)

+    if hidden_states_ref is not None:
+        hidden_states = hidden_states[:, :original_hidden_states_length]
     hidden_states = dit.final_norm_out(hidden_states, conditioning)
     hidden_states = dit.final_proj_out(hidden_states)
     hidden_states = dit.unpatchify(hidden_states, height, width)
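Inside lets_dance_flux, the reference latents travel as extra sequence tokens: patchified, embedded, flattened into the batch-1 sequence, attended to by every block, then cut off again before the final projection so the denoised output keeps its original length. The tensor bookkeeping in isolation (3072 is FLUX's hidden size; the token counts correspond to 1024x1024 images):

    import torch
    from einops import rearrange

    hidden_states = torch.randn(1, 4096, 3072)      # main image: 64 x 64 latent patches
    hidden_states_ref = torch.randn(2, 4096, 3072)  # two references, already patchified and embedded

    original_length = hidden_states.shape[1]
    ref_tokens = rearrange(hidden_states_ref, "B L C -> 1 (B L) C")
    hidden_states = torch.cat((hidden_states, ref_tokens), dim=1)
    print(hidden_states.shape)  # torch.Size([1, 12288, 3072]): blocks attend over all tokens

    hidden_states = hidden_states[:, :original_length]  # reference tokens are discarded
    print(hidden_states.shape)  # torch.Size([1, 4096, 3072])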
diff --git a/train_flux_reference.py b/train_flux_reference.py
new file mode 100644
index 0000000..a6ec6c7
--- /dev/null
+++ b/train_flux_reference.py
@@ -0,0 +1,206 @@
+from diffsynth import ModelManager, FluxImagePipeline
+from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
+from diffsynth.models.lora import FluxLoRAConverter
+from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
+import torch, os, argparse
+import lightning as pl
+from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
+from diffsynth.pipelines.flux_image import lets_dance_flux
+os.environ["TOKENIZERS_PARALLELISM"] = "True"
+
+
+class LightningModel(LightningModelForT2ILoRA):
+    def __init__(
+        self,
+        torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
+        learning_rate=1e-4, use_gradient_checkpointing=True,
+        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None,
+        state_dict_converter=None, quantize=None
+    ):
+        super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
+        # Load models
+        model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
+        if quantize is None:
+            model_manager.load_models(pretrained_weights)
+        else:
+            model_manager.load_models(pretrained_weights[1:])
+            model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
+        if preset_lora_path is not None:
+            preset_lora_path = preset_lora_path.split(",")
+            for path in preset_lora_path:
+                model_manager.load_lora(path)
+
+        self.pipe = FluxImagePipeline.from_model_manager(model_manager)
+        self.pipe.reference_embedder = FluxReferenceEmbedder().to(dtype=torch_dtype)  # required by lets_dance_flux when hidden_states_ref is given
+
+        if quantize is not None:
+            self.pipe.dit.quantize()
+
+        self.pipe.scheduler.set_timesteps(1000, training=True)
+
+        self.freeze_parameters()
+        self.add_lora_to_model(
+            self.pipe.denoising_model(),
+            lora_rank=lora_rank,
+            lora_alpha=lora_alpha,
+            lora_target_modules=lora_target_modules,
+            init_lora_weights=init_lora_weights,
+            pretrained_lora_path=pretrained_lora_path,
+            state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
+        )
+
+    def training_step(self, batch, batch_idx):
+        # Data (the dataset yields "instruction", "image_1", "image_2")
+        text, image = batch["instruction"], batch["image_2"]
+        image_ref = batch["image_1"]
+
+        # Prepare input parameters
+        self.pipe.device = self.device
+        prompt_emb = self.pipe.encode_prompt(text, positive=True)
+        if "latents" in batch:
+            latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
+        else:
+            latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
+        noise = torch.randn_like(latents)
+        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
+        timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
+        extra_input = self.pipe.prepare_extra_input(latents)
+        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
+        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
+
+        # Reference image
+        hidden_states_ref = self.pipe.vae_encoder(image_ref.to(dtype=self.pipe.torch_dtype, device=self.device))
+
+        # Compute loss
+        noise_pred = lets_dance_flux(
+            self.pipe.denoising_model(), reference_embedder=self.pipe.reference_embedder.to(self.device),
+            hidden_states_ref=hidden_states_ref,
+            hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
+            use_gradient_checkpointing=self.use_gradient_checkpointing
+        )
+        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
+        loss = loss * self.pipe.scheduler.training_weight(timestep)
+
+        # Record log
+        self.log("train_loss", loss, prog_bar=True)
+        return loss
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--pretrained_text_encoder_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to pretrained text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder/model.safetensors`.",
+    )
+    parser.add_argument(
+        "--pretrained_text_encoder_2_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to pretrained T5 text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder_2`.",
+    )
+    parser.add_argument(
+        "--pretrained_dit_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to pretrained DiT model. For example, `models/FLUX/FLUX.1-dev/flux1-dev.safetensors`.",
+    )
+    parser.add_argument(
+        "--pretrained_vae_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to pretrained VAE model. For example, `models/FLUX/FLUX.1-dev/ae.safetensors`.",
+    )
+    parser.add_argument(
+        "--lora_target_modules",
+        type=str,
+        default="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp",
+        help="Layers with LoRA modules.",
+    )
+    parser.add_argument(
+        "--align_to_opensource_format",
+        default=False,
+        action="store_true",
+        help="Whether to export LoRA files aligned with the open-source format.",
+    )
+    parser.add_argument(
+        "--quantize",
+        type=str,
+        default=None,
+        choices=["float8_e4m3fn"],
+        help="Whether to use quantization when training the model, and in which format.",
+    )
+    parser.add_argument(
+        "--preset_lora_path",
+        type=str,
+        default=None,
+        help="Preset LoRA path.",
+    )
+    parser = add_general_parsers(parser)
+    args = parser.parse_args()
+    return args
+
+
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_faceid.json", + ), + ], + dataset_weight=(4, 2, 2, 1), + ) + train_loader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=args.batch_size, + num_workers=args.dataloader_num_workers + ) + # train + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision=args.precision, + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], + logger=None, + ) + trainer.fit(model=model, train_dataloaders=train_loader)