From 44da204dbdb951b0683bddaf098403640ea91392 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 21 Apr 2025 15:48:25 +0800 Subject: [PATCH] lora merger --- diffsynth/models/utils.py | 13 +- diffsynth/pipelines/flux_image.py | 61 +++++--- diffsynth/vram_management/layers.py | 14 +- models/lora/Put lora files here.txt | 0 scripts/data_process.py | 85 ++++++++++++ scripts/test.py | 166 ++++++++++++++++++++++ scripts/train.py | 207 ++++++++++++++++++++++++++++ 7 files changed, 516 insertions(+), 30 deletions(-) delete mode 100644 models/lora/Put lora files here.txt create mode 100644 scripts/data_process.py create mode 100644 scripts/test.py create mode 100644 scripts/train.py diff --git a/diffsynth/models/utils.py b/diffsynth/models/utils.py index 99f5dee..1d658f2 100644 --- a/diffsynth/models/utils.py +++ b/diffsynth/models/utils.py @@ -62,25 +62,26 @@ def load_state_dict_from_folder(file_path, torch_dtype=None): return state_dict -def load_state_dict(file_path, torch_dtype=None): +def load_state_dict(file_path, torch_dtype=None, device="cpu"): if file_path.endswith(".safetensors"): - return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype) + return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device) else: - return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype) + return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device) -def load_state_dict_from_safetensors(file_path, torch_dtype=None): +def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"): state_dict = {} with safe_open(file_path, framework="pt", device="cpu") as f: for k in f.keys(): state_dict[k] = f.get_tensor(k) if torch_dtype is not None: state_dict[k] = state_dict[k].to(torch_dtype) + state_dict[k] = state_dict[k].to(device) return state_dict -def load_state_dict_from_bin(file_path, torch_dtype=None): - state_dict = torch.load(file_path, map_location="cpu", weights_only=True) +def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"): + state_dict = torch.load(file_path, map_location=device, weights_only=True) if torch_dtype is not None: for i in state_dict: if isinstance(state_dict[i], torch.Tensor): diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 39b33f1..21d19c9 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -401,7 +401,8 @@ class FluxImagePipeline(BasePipeline): progress_bar_cmd=tqdm, progress_bar_st=None, lora_state_dicts=[], - lora_alpahs=[] + lora_alpahs=[], + lora_patcher=None, ): height, width = self.check_resize_height_width(height, width) @@ -443,6 +444,7 @@ class FluxImagePipeline(BasePipeline): hidden_states=latents, timestep=timestep, lora_state_dicts=lora_state_dicts, lora_alpahs = lora_alpahs, + lora_patcher=lora_patcher, **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, ) noise_pred_posi = self.control_noise_via_local_prompts( @@ -462,6 +464,7 @@ class FluxImagePipeline(BasePipeline): hidden_states=latents, timestep=timestep, lora_state_dicts=lora_state_dicts, lora_alpahs = lora_alpahs, + lora_patcher=lora_patcher, **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) @@ -544,6 +547,7 @@ def lets_dance_flux( entity_masks=None, ipadapter_kwargs_list={}, tea_cache: TeaCache = None, + use_gradient_checkpointing=False, **kwargs ): @@ -610,6 +614,11 @@ def lets_dance_flux( prompt_emb = dit.context_embedder(prompt_emb) image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) attention_mask = None + + def create_custom_forward(module): + def custom_forward(*inputs, **kwargs): + return module(*inputs, **kwargs) + return custom_forward # TeaCache if tea_cache is not None: @@ -622,15 +631,22 @@ def lets_dance_flux( else: # Joint Blocks for block_id, block in enumerate(dit.blocks): - hidden_states, prompt_emb = block( - hidden_states, - prompt_emb, - conditioning, - image_rotary_emb, - attention_mask, - ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None), - **kwargs - ) + if use_gradient_checkpointing: + hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None), **kwargs, + use_reentrant=False, + ) + else: + hidden_states, prompt_emb = block( + hidden_states, + prompt_emb, + conditioning, + image_rotary_emb, + attention_mask, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None), + **kwargs + ) # ControlNet if controlnet is not None and controlnet_frames is not None: hidden_states = hidden_states + controlnet_res_stack[block_id] @@ -639,15 +655,22 @@ def lets_dance_flux( hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) num_joint_blocks = len(dit.blocks) for block_id, block in enumerate(dit.single_blocks): - hidden_states, prompt_emb = block( - hidden_states, - prompt_emb, - conditioning, - image_rotary_emb, - attention_mask, - ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None), - **kwargs - ) + if use_gradient_checkpointing: + hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None), **kwargs, + use_reentrant=False, + ) + else: + hidden_states, prompt_emb = block( + hidden_states, + prompt_emb, + conditioning, + image_rotary_emb, + attention_mask, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None), + **kwargs + ) # ControlNet if controlnet is not None and controlnet_frames is not None: hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py index f3d7b7c..2e8c5e6 100644 --- a/diffsynth/vram_management/layers.py +++ b/diffsynth/vram_management/layers.py @@ -71,15 +71,16 @@ class AutoWrappedLinear(torch.nn.Linear): return torch.nn.functional.linear(x, weight, bias) class AutoLoRALinear(torch.nn.Linear): - def __init__(self, name='', in_features=1, out_features=2, bias = True, device=None, dtype=None): + def __init__(self, name='', in_features=1, out_features=2, bias=True, device=None, dtype=None): super().__init__(in_features, out_features, bias, device, dtype) self.name = name - def forward(self, x, lora_state_dicts=[], lora_alpahs=[1.0,1.0], **kwargs): + def forward(self, x, lora_state_dicts=[], lora_alpahs=[1.0,1.0], lora_patcher=None, **kwargs): out = torch.nn.functional.linear(x, self.weight, self.bias) - lora_a_name = f'{self.name}.lora_A.weight' - lora_b_name = f'{self.name}.lora_B.weight' + lora_a_name = f'{self.name}.lora_A.default.weight' + lora_b_name = f'{self.name}.lora_B.default.weight' + lora_output = [] for i, lora_state_dict in enumerate(lora_state_dicts): if lora_state_dict is None: break @@ -87,7 +88,10 @@ class AutoLoRALinear(torch.nn.Linear): lora_A = lora_state_dict[lora_a_name].to(dtype=self.weight.dtype,device=self.weight.device) lora_B = lora_state_dict[lora_b_name].to(dtype=self.weight.dtype,device=self.weight.device) out_lora = x @ lora_A.T @ lora_B.T - out = out + out_lora * lora_alpahs[i] + lora_output.append(out_lora) + if len(lora_output) > 0: + lora_output = torch.stack(lora_output) + out = lora_patcher(out, lora_output, self.name) return out def enable_auto_lora(model:torch.nn.Module, module_map: dict, name_prefix=''): diff --git a/models/lora/Put lora files here.txt b/models/lora/Put lora files here.txt deleted file mode 100644 index e69de29..0000000 diff --git a/scripts/data_process.py b/scripts/data_process.py new file mode 100644 index 0000000..44340e9 --- /dev/null +++ b/scripts/data_process.py @@ -0,0 +1,85 @@ +import torch, os, dashscope +import pandas as pd +from tqdm import tqdm +from diffsynth import load_state_dict, hash_state_dict_keys + + +def search_for_model_file(path, allow_file_extensions=(".safetensors",)): + for file_name in os.listdir(path): + for file_extension in allow_file_extensions: + if file_name.endswith(file_extension): + return os.path.join(path, file_name) + + +def search_for_cover_images(path, allow_file_extensions=(".png", ".jpg", ".jpeg")): + image_files = [] + for file_name in os.listdir(path): + for file_extension in allow_file_extensions: + if file_name.endswith(file_extension): + image_files.append(os.path.join(path, file_name)) + break + return image_files + + +def search_for_lora_data(path): + model_file = search_for_model_file(path) + if "_cover_images_" not in os.listdir(path): + return None + image_files = search_for_cover_images(os.path.join(path, "_cover_images_")) + if model_file is None or len(image_files) == 0: + return None + state_dict = load_state_dict(model_file) + if hash_state_dict_keys(state_dict, with_shape=False) != "52544ae3076666228978b738fbb8b086": + return None + return model_file, image_files + + +def image_to_text(images=[], prompt="", system_prompt=None): + dashscope.api_key = "xxxxx" # TODO + messages = [] + if system_prompt is not None: + messages.append({"role": "system", "content": system_prompt}) + if not isinstance(images, list): + images = [images] + messages.append({"role": "user", "content": [{"text": prompt}] + [{"image": image} for image in images]}) + response = dashscope.MultiModalConversation.call(model="qwen-vl-max-latest", messages=messages) + response = response["output"]["choices"][0]["message"]["content"][0]["text"] + return response + + +qwen_i2t_prompt = ''' +You are a professional image captioner. +Generate a caption according to the image so that another image generation model can generate the image via the caption. Just return the string description, do not return anything else. +'''.strip() + + +def data_to_csv(model_file_list, image_file_list, text_list, save_path): + data_df = pd.DataFrame() + data_df["model_file"] = model_file_list + data_df["image_file"] = image_file_list + data_df["text"] = text_list + data_df.to_csv(save_path, index=False, encoding="utf-8-sig") + + +base_path = "/data/zhiwen/LoRA-Fusion/models/FLUXLoRA" + +model_file_list = [] +image_file_list = [] +text_list = [] + +for lora_name in tqdm(os.listdir(base_path)): + lora_folder_path = os.path.join(base_path, lora_name) + if os.path.isdir(lora_folder_path): + data = search_for_lora_data(lora_folder_path) + if data is not None: + model_file, image_files = data + for image_file in image_files: + try: + text = image_to_text(image_file, prompt=qwen_i2t_prompt) + except: + continue + model_file_list.append(model_file) + image_file_list.append(image_file) + text_list.append(text) + data_to_csv(model_file_list, image_file_list, text_list, "data/loras.csv") + diff --git a/scripts/test.py b/scripts/test.py new file mode 100644 index 0000000..8edc327 --- /dev/null +++ b/scripts/test.py @@ -0,0 +1,166 @@ +import torch, shutil, os +from diffsynth import ModelManager, FluxImagePipeline, download_models, load_state_dict +from diffsynth.models.lora import LoRAFromCivitai, FluxLoRAConverter +import pandas as pd +import torch +import pandas as pd +from PIL import Image +import lightning as pl +from diffsynth import ModelManager, FluxImagePipeline, download_models, load_state_dict +from diffsynth.models.lora import LoRAFromCivitai, FluxLoRAConverter +from diffsynth.data.video import crop_and_resize +from diffsynth.pipelines.flux_image import lets_dance_flux +from torchvision.transforms import v2 + + +baseline = "trained" + + +class LoraMerger(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.weight_base = torch.nn.Parameter(torch.randn((dim,))) + self.weight_lora = torch.nn.Parameter(torch.randn((dim,))) + self.weight_cross = torch.nn.Parameter(torch.randn((dim,))) + self.weight_out = torch.nn.Parameter(torch.ones((dim,))) + self.bias = torch.nn.Parameter(torch.randn((dim,))) + self.activation = torch.nn.Sigmoid() + self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5) + self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5) + + def forward(self, base_output, lora_outputs): + global baseline + if baseline == "nolora": + output = base_output + elif baseline == "lora1": + output = base_output + lora_outputs[0] + elif baseline == "lora2": + output = base_output + lora_outputs[1] + elif baseline == "alllora": + output = base_output + lora_outputs.sum(dim=0) + else: + norm_base_output = self.norm_base(base_output) + norm_lora_outputs = self.norm_lora(lora_outputs) + gate = self.activation( + norm_base_output * self.weight_base \ + + norm_lora_outputs * self.weight_lora \ + + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias + ) + output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0) + return output + + + +class LoraPatcher(torch.nn.Module): + def __init__(self, lora_patterns=None): + super().__init__() + if lora_patterns is None: + lora_patterns = self.default_lora_patterns() + model_dict = {} + for lora_pattern in lora_patterns: + name, dim = lora_pattern["name"], lora_pattern["dim"] + model_dict[name.replace(".", "___")] = LoraMerger(dim) + self.model_dict = torch.nn.ModuleDict(model_dict) + + def default_lora_patterns(self): + lora_patterns = [] + lora_dict = { + "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432, + "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432, + } + for i in range(19): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216} + for i in range(38): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"single_blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + return lora_patterns + + def forward(self, base_output, lora_outputs, name): + return self.model_dict[name.replace(".", "___")](base_output, lora_outputs) + + +class LoraDataset(torch.utils.data.Dataset): + def __init__(self, metadata_path, steps_per_epoch=1000): + data_df = pd.read_csv(metadata_path) + self.model_file = data_df["model_file"].tolist() + self.image_file = data_df["image_file"].tolist() + self.text = data_df["text"].tolist() + self.max_resolution = 1920 * 1080 + self.steps_per_epoch = steps_per_epoch + + + def read_image(self, image_file): + image = Image.open(image_file) + width, height = image.size + if width * height > self.max_resolution: + scale = (width * height / self.max_resolution) ** 0.5 + image = image.resize((int(width / scale), int(height / scale))) + width, height = image.size + if width % 16 != 0 or height % 16 != 0: + image = crop_and_resize(image, height // 16 * 16, width // 16 * 16) + image = v2.functional.to_image(image) + image = v2.functional.to_dtype(image, dtype=torch.float32, scale=True) + image = v2.functional.normalize(image, [0.5], [0.5]) + return image + + + def __getitem__(self, index): + data_id = torch.randint(0, len(self.model_file), (1,))[0] + data_id = (data_id + index) % len(self.model_file) # For fixed seed. + data_id_extra = torch.randint(0, len(self.model_file), (1,))[0] + return { + "model_file": self.model_file[data_id], + "model_file_extra": self.model_file[data_id_extra], + "image": self.read_image(self.image_file[data_id]), + "text": self.text[data_id] + } + + + def __len__(self): + return self.steps_per_epoch + + + + +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") +model_manager.load_models([ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + "models/FLUX/FLUX.1-dev/flux1-dev.safetensors" +]) +pipe = FluxImagePipeline.from_model_manager(model_manager) +pipe.enable_auto_lora() + + +lora_alpahs = [1, 1] +lora_patcher = LoraPatcher().to(dtype=torch.bfloat16, device="cuda") +lora_patcher.load_state_dict(load_state_dict("models/lightning_logs/version_13/checkpoints/epoch=2-step=1500.ckpt")) + +dataset = LoraDataset("data/loras_picked.csv") + +for seed in range(100): + data = dataset[0] + lora_state_dicts = [ + FluxLoRAConverter.align_to_diffsynth_format(load_state_dict(data["model_file"], torch_dtype=torch.bfloat16, device="cuda")), + FluxLoRAConverter.align_to_diffsynth_format(load_state_dict(data["model_file_extra"], torch_dtype=torch.bfloat16, device="cuda")), + ] + lora_alpahs = [1, 1] + for pattern in ["nolora", "lora1", "lora2", "alllora", "loramerger"]: + baseline = pattern + image = pipe( + prompt=data["text"], + lora_state_dicts=lora_state_dicts, + lora_alpahs=lora_alpahs, + lora_patcher=lora_patcher, + seed=seed, + ) + image.save(f"data/lora_outputs/image_{seed}_{pattern}.jpg") \ No newline at end of file diff --git a/scripts/train.py b/scripts/train.py new file mode 100644 index 0000000..2733851 --- /dev/null +++ b/scripts/train.py @@ -0,0 +1,207 @@ +import torch +import pandas as pd +from PIL import Image +import lightning as pl +from diffsynth import ModelManager, FluxImagePipeline, download_models, load_state_dict +from diffsynth.models.lora import LoRAFromCivitai, FluxLoRAConverter +from diffsynth.data.video import crop_and_resize +from diffsynth.pipelines.flux_image import lets_dance_flux +from torchvision.transforms import v2 + + + +class LoraMerger(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.weight_base = torch.nn.Parameter(torch.randn((dim,))) + self.weight_lora = torch.nn.Parameter(torch.randn((dim,))) + self.weight_cross = torch.nn.Parameter(torch.randn((dim,))) + self.weight_out = torch.nn.Parameter(torch.ones((dim,))) + self.bias = torch.nn.Parameter(torch.randn((dim,))) + self.activation = torch.nn.Sigmoid() + self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5) + self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5) + + def forward(self, base_output, lora_outputs): + norm_base_output = self.norm_base(base_output) + norm_lora_outputs = self.norm_lora(lora_outputs) + gate = self.activation( + norm_base_output * self.weight_base \ + + norm_lora_outputs * self.weight_lora \ + + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias + ) + output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0) + return output + + + +class LoraPatcher(torch.nn.Module): + def __init__(self, lora_patterns=None): + super().__init__() + if lora_patterns is None: + lora_patterns = self.default_lora_patterns() + model_dict = {} + for lora_pattern in lora_patterns: + name, dim = lora_pattern["name"], lora_pattern["dim"] + model_dict[name.replace(".", "___")] = LoraMerger(dim) + self.model_dict = torch.nn.ModuleDict(model_dict) + + def default_lora_patterns(self): + lora_patterns = [] + lora_dict = { + "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432, + "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432, + } + for i in range(19): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216} + for i in range(38): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"single_blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + return lora_patterns + + def forward(self, base_output, lora_outputs, name): + return self.model_dict[name.replace(".", "___")](base_output, lora_outputs) + + + +class LoraDataset(torch.utils.data.Dataset): + def __init__(self, metadata_path, steps_per_epoch=1000): + data_df = pd.read_csv(metadata_path) + self.model_file = data_df["model_file"].tolist() + self.image_file = data_df["image_file"].tolist() + self.text = data_df["text"].tolist() + self.max_resolution = 1920 * 1080 + self.steps_per_epoch = steps_per_epoch + + + def read_image(self, image_file): + image = Image.open(image_file) + width, height = image.size + if width * height > self.max_resolution: + scale = (width * height / self.max_resolution) ** 0.5 + image = image.resize((int(width / scale), int(height / scale))) + width, height = image.size + if width % 16 != 0 or height % 16 != 0: + image = crop_and_resize(image, height // 16 * 16, width // 16 * 16) + image = v2.functional.to_image(image) + image = v2.functional.to_dtype(image, dtype=torch.float32, scale=True) + image = v2.functional.normalize(image, [0.5], [0.5]) + return image + + + def __getitem__(self, index): + data_id = torch.randint(0, len(self.model_file), (1,))[0] + data_id = (data_id + index) % len(self.model_file) # For fixed seed. + data_id_extra = torch.randint(0, len(self.model_file), (1,))[0] + return { + "model_file": self.model_file[data_id], + "model_file_extra": self.model_file[data_id_extra], + "image": self.read_image(self.image_file[data_id]), + "text": self.text[data_id] + } + + + def __len__(self): + return self.steps_per_epoch + + + +class LightningModel(pl.LightningModule): + def __init__( + self, + learning_rate=1e-4, + use_gradient_checkpointing=True, + state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format, + ): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device=self.device, model_id_list=["FLUX.1-dev"]) + self.pipe = FluxImagePipeline.from_model_manager(model_manager) + self.lora_patcher = LoraPatcher() + self.pipe.enable_auto_lora() + self.pipe.scheduler.set_timesteps(1000, training=True) + self.freeze_parameters() + # Set parameters + self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + self.state_dict_converter = state_dict_converter + + + def freeze_parameters(self): + # Freeze parameters + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + + def training_step(self, batch, batch_idx): + # Data + text, image = batch["text"], batch["image"] + lora_state_dicts = [ + self.state_dict_converter(load_state_dict(batch["model_file"][0], torch_dtype=torch.bfloat16, device=self.device)), + self.state_dict_converter(load_state_dict(batch["model_file_extra"][0], torch_dtype=torch.bfloat16, device=self.device)), + ] + lora_alpahs = [1, 1] + + # Prepare input parameters + self.pipe.device = self.device + prompt_emb = self.pipe.encode_prompt(text, positive=True) + if "latents" in batch: + latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device) + else: + latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device)) + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device) + extra_input = self.pipe.prepare_extra_input(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + + # Compute loss + noise_pred = lets_dance_flux( + self.pipe.dit, + hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, + lora_state_dicts=lora_state_dicts, lora_alpahs=lora_alpahs, lora_patcher=self.lora_patcher, + use_gradient_checkpointing=self.use_gradient_checkpointing + ) + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * self.pipe.scheduler.training_weight(timestep) + + # Record log + self.log("train_loss", loss, prog_bar=True) + return loss + + + def configure_optimizers(self): + trainable_modules = filter(lambda p: p.requires_grad, self.lora_patcher.parameters()) + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer + + + def on_save_checkpoint(self, checkpoint): + checkpoint.clear() + checkpoint.update(self.lora_patcher.state_dict()) + + +if __name__ == '__main__': + model = LightningModel(learning_rate=1e-4) + dataset = LoraDataset("data/loras.csv", steps_per_epoch=500) + train_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=1) + trainer = pl.Trainer( + max_epochs=100000, + accelerator="gpu", + devices="auto", + precision="bf16", + strategy="auto", + default_root_dir="./models", + accumulate_grad_batches=1, + callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], + ) + trainer.fit(model=model, train_dataloaders=train_loader)