From 50d2c86ae569f3ba57c9850af0effd0d08f8e96e Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Mon, 23 Jun 2025 17:34:30 +0800
Subject: [PATCH] Add LoRA retrieval and merging for FLUX

Add a lora/ package containing a LoRA dataset loader, a gated LoRA merger
(LoraMerger/LoraPatcher), a text/LoRA embedding retriever, and training and
test scripts for both components. Remove the old scripts/ prototypes and fix
the lora_alpahs -> lora_alphas typo in the FLUX pipeline and AutoLoRALinear.
---
 diffsynth/pipelines/flux_image.py   |   6 +-
 diffsynth/vram_management/layers.py |   2 +-
 lora/__init__.py                    |   0
 lora/dataset.py                     |  54 ++++++++
 lora/merger.py                      |  61 ++++++++
 lora/retriever.py                   | 149 ++++++++++++++++++++
 lora/test_merger.py                 |  46 +++++++
 lora/test_retriever.py              | 148 ++++++++++++++++++++
 lora/train_merger.py                | 119 ++++++++++++++++
 lora/train_retriever.py             | 105 ++++++++++++++
 lora/utils.py                       |  12 ++
 scripts/data_process.py             |  85 ------------
 scripts/test.py                     | 166 ----------------------
 scripts/train.py                    | 207 ----------------------------
 14 files changed, 698 insertions(+), 462 deletions(-)
 create mode 100644 lora/__init__.py
 create mode 100644 lora/dataset.py
 create mode 100644 lora/merger.py
 create mode 100644 lora/retriever.py
 create mode 100644 lora/test_merger.py
 create mode 100644 lora/test_retriever.py
 create mode 100644 lora/train_merger.py
 create mode 100644 lora/train_retriever.py
 create mode 100644 lora/utils.py
 delete mode 100644 scripts/data_process.py
 delete mode 100644 scripts/test.py
 delete mode 100644 scripts/train.py

diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py
index 21d19c9..2f4d10d 100644
--- a/diffsynth/pipelines/flux_image.py
+++ b/diffsynth/pipelines/flux_image.py
@@ -401,7 +401,7 @@ class FluxImagePipeline(BasePipeline):
         progress_bar_cmd=tqdm,
         progress_bar_st=None,
         lora_state_dicts=[],
-        lora_alpahs=[],
+        lora_alphas=[],
         lora_patcher=None,
     ):
         height, width = self.check_resize_height_width(height, width)
@@ -443,7 +443,7 @@ class FluxImagePipeline(BasePipeline):
                 dit=self.dit, controlnet=self.controlnet,
                 hidden_states=latents, timestep=timestep,
                 lora_state_dicts=lora_state_dicts,
-                lora_alpahs = lora_alpahs,
+                lora_alphas = lora_alphas,
                 lora_patcher=lora_patcher,
                 **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
             )
@@ -463,7 +463,7 @@ class FluxImagePipeline(BasePipeline):
                 dit=self.dit, controlnet=self.controlnet,
                 hidden_states=latents, timestep=timestep,
                 lora_state_dicts=lora_state_dicts,
-                lora_alpahs = lora_alpahs,
+                lora_alphas = lora_alphas,
                 lora_patcher=lora_patcher,
                 **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
             )
diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py
index 2e8c5e6..e998cac 100644
--- a/diffsynth/vram_management/layers.py
+++ b/diffsynth/vram_management/layers.py
@@ -75,7 +75,7 @@ class AutoLoRALinear(torch.nn.Linear):
         super().__init__(in_features, out_features, bias, device, dtype)
         self.name = name
 
-    def forward(self, x, lora_state_dicts=[], lora_alpahs=[1.0,1.0], lora_patcher=None, **kwargs):
+    def forward(self, x, lora_state_dicts=[], lora_alphas=[1.0,1.0], lora_patcher=None, **kwargs):
         out = torch.nn.functional.linear(x, self.weight, self.bias)
         lora_a_name = f'{self.name}.lora_A.default.weight'
         lora_b_name = f'{self.name}.lora_B.default.weight'
diff --git a/lora/__init__.py b/lora/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/lora/dataset.py b/lora/dataset.py
new file mode 100644
index 0000000..dc21adf
--- /dev/null
+++ b/lora/dataset.py
@@ -0,0 +1,54 @@
+import torch, os
+import pandas as pd
+from PIL import Image
+from torchvision.transforms import v2
+from diffsynth.data.video import crop_and_resize
+
+
+class 
LoraDataset(torch.utils.data.Dataset): + def __init__(self, base_path, metadata_path, steps_per_epoch=1000, loras_per_item=1): + self.base_path = base_path + data_df = pd.read_csv(metadata_path) + self.model_file = data_df["model_file"].tolist() + self.image_file = data_df["image_file"].tolist() + self.text = data_df["text"].tolist() + self.max_resolution = 1920 * 1080 + self.steps_per_epoch = steps_per_epoch + self.loras_per_item = loras_per_item + + + def read_image(self, image_file): + image = Image.open(image_file).convert("RGB") + width, height = image.size + if width * height > self.max_resolution: + scale = (width * height / self.max_resolution) ** 0.5 + image = image.resize((int(width / scale), int(height / scale))) + width, height = image.size + if width % 16 != 0 or height % 16 != 0: + image = crop_and_resize(image, height // 16 * 16, width // 16 * 16) + image = v2.functional.to_image(image) + image = v2.functional.to_dtype(image, dtype=torch.float32, scale=True) + image = v2.functional.normalize(image, [0.5], [0.5]) + return image + + + def get_data(self, data_id): + data = { + "model_file": os.path.join(self.base_path, self.model_file[data_id]), + "image": self.read_image(os.path.join(self.base_path, self.image_file[data_id])), + "text": self.text[data_id] + } + return data + + + def __getitem__(self, index): + data = [] + while len(data) < self.loras_per_item: + data_id = torch.randint(0, len(self.model_file), (1,))[0] + data_id = (data_id + index) % len(self.model_file) # For fixed seed. + data.append(self.get_data(data_id)) + return data + + + def __len__(self): + return self.steps_per_epoch \ No newline at end of file diff --git a/lora/merger.py b/lora/merger.py new file mode 100644 index 0000000..af6d4ae --- /dev/null +++ b/lora/merger.py @@ -0,0 +1,61 @@ +import torch + + +class LoraMerger(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.weight_base = torch.nn.Parameter(torch.randn((dim,))) + self.weight_lora = torch.nn.Parameter(torch.randn((dim,))) + self.weight_cross = torch.nn.Parameter(torch.randn((dim,))) + self.weight_out = torch.nn.Parameter(torch.ones((dim,))) + self.bias = torch.nn.Parameter(torch.randn((dim,))) + self.activation = torch.nn.Sigmoid() + self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5) + self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5) + + def forward(self, base_output, lora_outputs): + norm_base_output = self.norm_base(base_output) + norm_lora_outputs = self.norm_lora(lora_outputs) + gate = self.activation( + norm_base_output * self.weight_base \ + + norm_lora_outputs * self.weight_lora \ + + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias + ) + output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0) + return output + + +class LoraPatcher(torch.nn.Module): + def __init__(self, lora_patterns=None): + super().__init__() + if lora_patterns is None: + lora_patterns = self.default_lora_patterns() + model_dict = {} + for lora_pattern in lora_patterns: + name, dim = lora_pattern["name"], lora_pattern["dim"] + model_dict[name.replace(".", "___")] = LoraMerger(dim) + self.model_dict = torch.nn.ModuleDict(model_dict) + + def default_lora_patterns(self): + lora_patterns = [] + lora_dict = { + "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432, + "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432, + } + for i in range(19): + for suffix in lora_dict: + lora_patterns.append({ + "name": 
f"blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216} + for i in range(38): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"single_blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + return lora_patterns + + def forward(self, base_output, lora_outputs, name): + return self.model_dict[name.replace(".", "___")](base_output, lora_outputs) \ No newline at end of file diff --git a/lora/retriever.py b/lora/retriever.py new file mode 100644 index 0000000..d77a951 --- /dev/null +++ b/lora/retriever.py @@ -0,0 +1,149 @@ +import torch +from diffsynth import SDTextEncoder +from diffsynth.models.sd3_text_encoder import SD3TextEncoder1StateDictConverter +from diffsynth.models.sd_text_encoder import CLIPEncoderLayer + + +class LoRALayerBlock(torch.nn.Module): + def __init__(self, L, dim_in): + super().__init__() + self.x = torch.nn.Parameter(torch.randn(1, L, dim_in)) + + def forward(self, lora_A, lora_B): + out = self.x @ lora_A.T @ lora_B.T + return out + + +class LoRAEmbedder(torch.nn.Module): + def __init__(self, lora_patterns=None, L=1, out_dim=2048): + super().__init__() + if lora_patterns is None: + lora_patterns = self.default_lora_patterns() + + model_dict = {} + for lora_pattern in lora_patterns: + name, dim = lora_pattern["name"], lora_pattern["dim"][0] + model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim) + self.model_dict = torch.nn.ModuleDict(model_dict) + + proj_dict = {} + for lora_pattern in lora_patterns: + layer_type, dim = lora_pattern["type"], lora_pattern["dim"][1] + if layer_type not in proj_dict: + proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim, out_dim) + self.proj_dict = torch.nn.ModuleDict(proj_dict) + + self.lora_patterns = lora_patterns + + + def default_lora_patterns(self): + lora_patterns = [] + lora_dict = { + "attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432), + "attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432), + } + for i in range(19): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"blocks.{i}.{suffix}", + "dim": lora_dict[suffix], + "type": suffix, + }) + lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)} + for i in range(38): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"single_blocks.{i}.{suffix}", + "dim": lora_dict[suffix], + "type": suffix, + }) + return lora_patterns + + def forward(self, lora): + lora_emb = [] + for lora_pattern in self.lora_patterns: + name, layer_type = lora_pattern["name"], lora_pattern["type"] + lora_A = lora[name + ".lora_A.default.weight"] + lora_B = lora[name + ".lora_B.default.weight"] + lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B) + lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out) + lora_emb.append(lora_out) + lora_emb = torch.concat(lora_emb, dim=1) + return lora_emb + + +class TextEncoder(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + 
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, input_ids, clip_skip=1): + embeds = self.token_embedding(input_ids) + self.position_embeds + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == len(self.encoders): + break + embeds = self.final_layer_norm(embeds) + pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)] + return pooled_embeds + + @staticmethod + def state_dict_converter(): + return SD3TextEncoder1StateDictConverter() + + +class LoRAEncoder(torch.nn.Module): + def __init__(self, embed_dim=768, max_position_embeddings=304, num_encoder_layers=2, encoder_intermediate_size=3072, L=1): + super().__init__() + max_position_embeddings *= L + + # Embedder + self.embedder = LoRAEmbedder(L=L, out_dim=embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, lora): + embeds = self.embedder(lora) + self.position_embeds + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + embeds = self.final_layer_norm(embeds) + embeds = embeds.mean(dim=1) + return embeds diff --git a/lora/test_merger.py b/lora/test_merger.py new file mode 100644 index 0000000..6266230 --- /dev/null +++ b/lora/test_merger.py @@ -0,0 +1,46 @@ +from diffsynth import FluxImagePipeline, ModelManager, load_state_dict +from diffsynth.models.lora import FluxLoRAConverter +from diffsynth.pipelines.flux_image import lets_dance_flux +from lora.dataset import LoraDataset +from lora.merger import LoraPatcher +from lora.utils import load_lora +import torch, os +from accelerate import Accelerator, DistributedDataParallelKwargs +from tqdm import tqdm + + +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) +pipe = FluxImagePipeline.from_model_manager(model_manager) +pipe.enable_auto_lora() + +lora_patcher = LoraPatcher().to(dtype=torch.bfloat16, device="cuda") +lora_patcher.load_state_dict(load_state_dict("models/lora_merger/epoch-3.safetensors")) + +dataset = LoraDataset("data/lora/models", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=4) + +for seed in range(100): + batch = dataset[0] + num_lora = torch.randint(1, len(batch), (1,))[0] + lora_state_dicts = [ + FluxLoRAConverter.align_to_diffsynth_format(load_lora(batch[i]["model_file"], device="cuda")) for i in range(num_lora) + ] + image = pipe( + 
prompt=batch[0]["text"], + seed=seed, + ) + image.save(f"data/lora/lora_outputs/image_{seed}_nolora.jpg") + for i in range(num_lora): + image = pipe( + prompt=batch[0]["text"], + lora_state_dicts=[lora_state_dicts[i]], + lora_patcher=lora_patcher, + seed=seed, + ) + image.save(f"data/lora/lora_outputs/image_{seed}_{i}.jpg") + image = pipe( + prompt=batch[0]["text"], + lora_state_dicts=lora_state_dicts, + lora_patcher=lora_patcher, + seed=seed, + ) + image.save(f"data/lora/lora_outputs/image_{seed}_merger.jpg") diff --git a/lora/test_retriever.py b/lora/test_retriever.py new file mode 100644 index 0000000..fc66707 --- /dev/null +++ b/lora/test_retriever.py @@ -0,0 +1,148 @@ +from diffsynth import FluxImagePipeline, ModelManager, load_state_dict +from diffsynth.models.lora import FluxLoRAConverter +from diffsynth.pipelines.flux_image import lets_dance_flux +from lora.dataset import LoraDataset +from lora.retriever import TextEncoder, LoRAEncoder +from lora.merger import LoraPatcher +from lora.utils import load_lora +import torch, os +from accelerate import Accelerator, DistributedDataParallelKwargs +from tqdm import tqdm +from transformers import CLIPTokenizer, CLIPModel +import pandas as pd + + + +class LoRARetrieverTrainingModel(torch.nn.Module): + def __init__(self, pretrained_path): + super().__init__() + + self.text_encoder = TextEncoder().to(torch.bfloat16) + state_dict = load_state_dict("models/FLUX/FLUX.1-dev/text_encoder/model.safetensors") + self.text_encoder.load_state_dict(TextEncoder.state_dict_converter().from_civitai(state_dict)) + self.text_encoder.requires_grad_(False) + self.text_encoder.eval() + + self.lora_encoder = LoRAEncoder().to(torch.bfloat16) + state_dict = load_state_dict(pretrained_path) + self.lora_encoder.load_state_dict(state_dict) + + self.tokenizer = CLIPTokenizer.from_pretrained("diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1") + + + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + if device is not None: + self.device = device + if dtype is not None: + self.torch_dtype = dtype + super().to(*args, **kwargs) + return self + + + def forward(self, batch): + text = [data["text"] for data in batch] + input_ids = self.tokenizer( + text, + return_tensors="pt", + padding="max_length", + max_length=77, + truncation=True + ).input_ids.to(self.device) + text_emb = self.text_encoder(input_ids) + text_emb = text_emb / text_emb.norm() + + lora_emb = [] + for data in batch: + lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(data["model_file"], device=self.device)) + lora_emb.append(self.lora_encoder(lora)) + lora_emb = torch.concat(lora_emb) + lora_emb = lora_emb / lora_emb.norm() + + similarity = text_emb @ lora_emb.T + print(similarity) + loss = -torch.log(torch.softmax(similarity, dim=0).diag()) - torch.log(torch.softmax(similarity, dim=1).diag()) + loss = 10 * loss.mean() + return loss + + + def trainable_modules(self): + return self.lora_encoder.parameters() + + @torch.no_grad() + def process_lora_list(self, lora_list): + lora_emb = [] + for lora in tqdm(lora_list): + lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(lora, device="cuda")) + lora_emb.append(self.lora_encoder(lora)) + lora_emb = torch.concat(lora_emb) + lora_emb = lora_emb / lora_emb.norm() + self.lora_emb = lora_emb + self.lora_list = lora_list + + @torch.no_grad() + def retrieve(self, text, k=1): + input_ids = self.tokenizer( + text, + return_tensors="pt", + padding="max_length", + 
max_length=77, + truncation=True + ).input_ids.to(self.device) + text_emb = self.text_encoder(input_ids) + text_emb = text_emb / text_emb.norm() + + similarity = text_emb @ self.lora_emb.T + topk = torch.topk(similarity, k, dim=1).indices[0] + + lora_list = [] + model_url_list = [] + for lora_id in topk: + print(self.lora_list[lora_id]) + lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(self.lora_list[lora_id], device="cuda")) + lora_list.append(lora) + model_id = self.lora_list[lora_id].split("/")[3:5] + model_url_list.append(f"https://www.modelscope.cn/models/{model_id[0]}/{model_id[1]}") + return lora_list, model_url_list + + + +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) +pipe = FluxImagePipeline.from_model_manager(model_manager) +pipe.enable_auto_lora() + +lora_patcher = LoraPatcher().to(dtype=torch.bfloat16, device="cuda") +lora_patcher.load_state_dict(load_state_dict("models/lora_merger/epoch-9.safetensors")) + +retriever = LoRARetrieverTrainingModel("models/lora_retriever/epoch-3.safetensors").to(dtype=torch.bfloat16, device="cuda") +retriever.process_lora_list(list(set("data/lora/models/" + i for i in pd.read_csv("data/lora/lora_dataset_1000.csv")["model_file"]))) + +dataset = LoraDataset("data/lora/models", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=1) + +text_list = [] +model_url_list = [] +for seed in range(100): + text = dataset[0][0]["text"] + print(text) + loras, urls = retriever.retrieve(text, k=3) + print(urls) + image = pipe( + prompt=text, + seed=seed, + ) + image.save(f"data/lora/lora_outputs/image_{seed}_top0.jpg") + for i in range(2, 3): + image = pipe( + prompt=text, + lora_state_dicts=loras[:i+1], + lora_patcher=lora_patcher, + seed=seed, + ) + image.save(f"data/lora/lora_outputs/image_{seed}_top{i+1}.jpg") + + text_list.append(text) + model_url_list.append(urls) + df = pd.DataFrame() + df["text"] = text_list + df["models"] = [",".join(i) for i in model_url_list] + df.to_csv("data/lora/lora_outputs/metadata.csv", index=False, encoding="utf-8-sig") \ No newline at end of file diff --git a/lora/train_merger.py b/lora/train_merger.py new file mode 100644 index 0000000..7fbfe41 --- /dev/null +++ b/lora/train_merger.py @@ -0,0 +1,119 @@ +from diffsynth import FluxImagePipeline, ModelManager +from diffsynth.models.lora import FluxLoRAConverter +from diffsynth.pipelines.flux_image import lets_dance_flux +from lora.dataset import LoraDataset +from lora.merger import LoraPatcher +from lora.utils import load_lora +import torch, os +from accelerate import Accelerator, DistributedDataParallelKwargs +from tqdm import tqdm + + + +class LoRAMergerTrainingModel(torch.nn.Module): + def __init__(self): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu", model_id_list=["FLUX.1-dev"]) + self.pipe = FluxImagePipeline.from_model_manager(model_manager) + self.lora_patcher = LoraPatcher() + self.pipe.enable_auto_lora() + self.freeze_parameters() + self.switch_to_training_mode() + self.use_gradient_checkpointing = True + self.state_dict_converter = FluxLoRAConverter.align_to_diffsynth_format + self.device = "cuda" + + + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + if device is not None: + self.device = device + if dtype is not None: + self.torch_dtype = dtype + super().to(*args, **kwargs) + return self + + + def switch_to_training_mode(self): + 
self.pipe.scheduler.set_timesteps(1000, training=True) + + + def freeze_parameters(self): + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + self.lora_patcher.requires_grad_(True) + + + def forward(self, batch): + # Data + text, image = batch[0]["text"], batch[0]["image"].unsqueeze(0) + num_lora = torch.randint(1, len(batch), (1,))[0] + lora_state_dicts = [ + self.state_dict_converter(load_lora(batch[i]["model_file"], device=self.device)) for i in range(num_lora) + ] + lora_alphas = None + + # Prepare input parameters + self.pipe.device = self.device + prompt_emb = self.pipe.encode_prompt(text, positive=True) + latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device)) + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device) + extra_input = self.pipe.prepare_extra_input(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + + # Compute loss + noise_pred = lets_dance_flux( + self.pipe.dit, + hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, + lora_state_dicts=lora_state_dicts, lora_alphas=lora_alphas, lora_patcher=self.lora_patcher, + use_gradient_checkpointing=self.use_gradient_checkpointing + ) + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * self.pipe.scheduler.training_weight(timestep) + return loss + + + def trainable_modules(self): + return self.lora_patcher.parameters() + + +class ModelLogger: + def __init__(self, output_path, remove_prefix_in_ckpt=None): + self.output_path = output_path + self.remove_prefix_in_ckpt = remove_prefix_in_ckpt + + + def on_step_end(self, loss): + pass + + + def on_epoch_end(self, accelerator, model, epoch_id): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + state_dict = accelerator.unwrap_model(model).lora_patcher.state_dict() + os.makedirs(self.output_path, exist_ok=True) + path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors") + accelerator.save(state_dict, path, safe_serialization=True) + + +if __name__ == '__main__': + model = LoRAMergerTrainingModel() + dataset = LoraDataset("data/lora/models/", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=4) + dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=1, collate_fn=lambda x: x[0]) + optimizer = torch.optim.AdamW(model.trainable_modules(), lr=1e-4) + model_logger = ModelLogger("models/lora_merger") + accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)]) + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + + for epoch_id in range(1000000): + for data in tqdm(dataloader): + with accelerator.accumulate(model): + optimizer.zero_grad() + loss = model(data) + accelerator.backward(loss) + optimizer.step() + model_logger.on_epoch_end(accelerator, model, epoch_id) diff --git a/lora/train_retriever.py b/lora/train_retriever.py new file mode 100644 index 0000000..dd96890 --- /dev/null +++ b/lora/train_retriever.py @@ -0,0 +1,105 @@ +from diffsynth import FluxImagePipeline, ModelManager, load_state_dict +from diffsynth.models.lora import FluxLoRAConverter +from diffsynth.pipelines.flux_image import lets_dance_flux +from lora.dataset import LoraDataset 
+from lora.retriever import TextEncoder, LoRAEncoder +from lora.utils import load_lora +import torch, os +from accelerate import Accelerator, DistributedDataParallelKwargs +from tqdm import tqdm +from transformers import CLIPTokenizer, CLIPModel + + + +class LoRARetrieverTrainingModel(torch.nn.Module): + def __init__(self): + super().__init__() + + self.text_encoder = TextEncoder().to(torch.bfloat16) + state_dict = load_state_dict("models/FLUX/FLUX.1-dev/text_encoder/model.safetensors") + self.text_encoder.load_state_dict(TextEncoder.state_dict_converter().from_civitai(state_dict)) + self.text_encoder.requires_grad_(False) + self.text_encoder.eval() + + self.lora_encoder = LoRAEncoder().to(torch.bfloat16) + + self.tokenizer = CLIPTokenizer.from_pretrained("diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1") + + + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + if device is not None: + self.device = device + if dtype is not None: + self.torch_dtype = dtype + super().to(*args, **kwargs) + return self + + + def forward(self, batch): + text = [data["text"] for data in batch] + input_ids = self.tokenizer( + text, + return_tensors="pt", + padding="max_length", + max_length=77, + truncation=True + ).input_ids.to(self.device) + text_emb = self.text_encoder(input_ids) + text_emb = text_emb / text_emb.norm() + + lora_emb = [] + for data in batch: + lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(data["model_file"], device=self.device)) + lora_emb.append(self.lora_encoder(lora)) + lora_emb = torch.concat(lora_emb) + lora_emb = lora_emb / lora_emb.norm() + + similarity = text_emb @ lora_emb.T + print(similarity) + loss = -torch.log(torch.softmax(similarity, dim=0).diag()) - torch.log(torch.softmax(similarity, dim=1).diag()) + loss = 10 * loss.mean() + return loss + + + def trainable_modules(self): + return self.lora_encoder.parameters() + + +class ModelLogger: + def __init__(self, output_path, remove_prefix_in_ckpt=None): + self.output_path = output_path + self.remove_prefix_in_ckpt = remove_prefix_in_ckpt + + + def on_step_end(self, loss): + pass + + + def on_epoch_end(self, accelerator, model, epoch_id): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + state_dict = accelerator.unwrap_model(model).lora_encoder.state_dict() + os.makedirs(self.output_path, exist_ok=True) + path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors") + accelerator.save(state_dict, path, safe_serialization=True) + + +if __name__ == '__main__': + model = LoRARetrieverTrainingModel() + dataset = LoraDataset("data/lora/models/", "data/lora/lora_dataset_1000.csv", steps_per_epoch=100, loras_per_item=32) + dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=1, collate_fn=lambda x: x[0]) + optimizer = torch.optim.AdamW(model.trainable_modules(), lr=1e-4) + model_logger = ModelLogger("models/lora_retriever") + accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)]) + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + + for epoch_id in range(1000000): + for data in tqdm(dataloader): + with accelerator.accumulate(model): + optimizer.zero_grad() + loss = model(data) + accelerator.backward(loss) + optimizer.step() + print(loss) + model_logger.on_epoch_end(accelerator, model, epoch_id) diff --git a/lora/utils.py b/lora/utils.py new file mode 100644 index 0000000..3191043 --- /dev/null +++ 
b/lora/utils.py @@ -0,0 +1,12 @@ +from diffsynth import load_state_dict +import math, torch + + +def load_lora(file_path, device): + sd = load_state_dict(file_path, torch_dtype=torch.bfloat16, device=device) + scale = math.sqrt(sd["lora_unet_single_blocks_9_modulation_lin.alpha"] / sd["lora_unet_single_blocks_9_modulation_lin.lora_down.weight"].shape[0]) + if scale != 1: + sd = {i: sd[i] * scale for i in sd} + return sd + + diff --git a/scripts/data_process.py b/scripts/data_process.py deleted file mode 100644 index 44340e9..0000000 --- a/scripts/data_process.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch, os, dashscope -import pandas as pd -from tqdm import tqdm -from diffsynth import load_state_dict, hash_state_dict_keys - - -def search_for_model_file(path, allow_file_extensions=(".safetensors",)): - for file_name in os.listdir(path): - for file_extension in allow_file_extensions: - if file_name.endswith(file_extension): - return os.path.join(path, file_name) - - -def search_for_cover_images(path, allow_file_extensions=(".png", ".jpg", ".jpeg")): - image_files = [] - for file_name in os.listdir(path): - for file_extension in allow_file_extensions: - if file_name.endswith(file_extension): - image_files.append(os.path.join(path, file_name)) - break - return image_files - - -def search_for_lora_data(path): - model_file = search_for_model_file(path) - if "_cover_images_" not in os.listdir(path): - return None - image_files = search_for_cover_images(os.path.join(path, "_cover_images_")) - if model_file is None or len(image_files) == 0: - return None - state_dict = load_state_dict(model_file) - if hash_state_dict_keys(state_dict, with_shape=False) != "52544ae3076666228978b738fbb8b086": - return None - return model_file, image_files - - -def image_to_text(images=[], prompt="", system_prompt=None): - dashscope.api_key = "xxxxx" # TODO - messages = [] - if system_prompt is not None: - messages.append({"role": "system", "content": system_prompt}) - if not isinstance(images, list): - images = [images] - messages.append({"role": "user", "content": [{"text": prompt}] + [{"image": image} for image in images]}) - response = dashscope.MultiModalConversation.call(model="qwen-vl-max-latest", messages=messages) - response = response["output"]["choices"][0]["message"]["content"][0]["text"] - return response - - -qwen_i2t_prompt = ''' -You are a professional image captioner. -Generate a caption according to the image so that another image generation model can generate the image via the caption. Just return the string description, do not return anything else. 
-'''.strip() - - -def data_to_csv(model_file_list, image_file_list, text_list, save_path): - data_df = pd.DataFrame() - data_df["model_file"] = model_file_list - data_df["image_file"] = image_file_list - data_df["text"] = text_list - data_df.to_csv(save_path, index=False, encoding="utf-8-sig") - - -base_path = "/data/zhiwen/LoRA-Fusion/models/FLUXLoRA" - -model_file_list = [] -image_file_list = [] -text_list = [] - -for lora_name in tqdm(os.listdir(base_path)): - lora_folder_path = os.path.join(base_path, lora_name) - if os.path.isdir(lora_folder_path): - data = search_for_lora_data(lora_folder_path) - if data is not None: - model_file, image_files = data - for image_file in image_files: - try: - text = image_to_text(image_file, prompt=qwen_i2t_prompt) - except: - continue - model_file_list.append(model_file) - image_file_list.append(image_file) - text_list.append(text) - data_to_csv(model_file_list, image_file_list, text_list, "data/loras.csv") - diff --git a/scripts/test.py b/scripts/test.py deleted file mode 100644 index 8edc327..0000000 --- a/scripts/test.py +++ /dev/null @@ -1,166 +0,0 @@ -import torch, shutil, os -from diffsynth import ModelManager, FluxImagePipeline, download_models, load_state_dict -from diffsynth.models.lora import LoRAFromCivitai, FluxLoRAConverter -import pandas as pd -import torch -import pandas as pd -from PIL import Image -import lightning as pl -from diffsynth import ModelManager, FluxImagePipeline, download_models, load_state_dict -from diffsynth.models.lora import LoRAFromCivitai, FluxLoRAConverter -from diffsynth.data.video import crop_and_resize -from diffsynth.pipelines.flux_image import lets_dance_flux -from torchvision.transforms import v2 - - -baseline = "trained" - - -class LoraMerger(torch.nn.Module): - def __init__(self, dim): - super().__init__() - self.weight_base = torch.nn.Parameter(torch.randn((dim,))) - self.weight_lora = torch.nn.Parameter(torch.randn((dim,))) - self.weight_cross = torch.nn.Parameter(torch.randn((dim,))) - self.weight_out = torch.nn.Parameter(torch.ones((dim,))) - self.bias = torch.nn.Parameter(torch.randn((dim,))) - self.activation = torch.nn.Sigmoid() - self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5) - self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5) - - def forward(self, base_output, lora_outputs): - global baseline - if baseline == "nolora": - output = base_output - elif baseline == "lora1": - output = base_output + lora_outputs[0] - elif baseline == "lora2": - output = base_output + lora_outputs[1] - elif baseline == "alllora": - output = base_output + lora_outputs.sum(dim=0) - else: - norm_base_output = self.norm_base(base_output) - norm_lora_outputs = self.norm_lora(lora_outputs) - gate = self.activation( - norm_base_output * self.weight_base \ - + norm_lora_outputs * self.weight_lora \ - + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias - ) - output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0) - return output - - - -class LoraPatcher(torch.nn.Module): - def __init__(self, lora_patterns=None): - super().__init__() - if lora_patterns is None: - lora_patterns = self.default_lora_patterns() - model_dict = {} - for lora_pattern in lora_patterns: - name, dim = lora_pattern["name"], lora_pattern["dim"] - model_dict[name.replace(".", "___")] = LoraMerger(dim) - self.model_dict = torch.nn.ModuleDict(model_dict) - - def default_lora_patterns(self): - lora_patterns = [] - lora_dict = { - "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, 
"norm1_a.linear": 18432, - "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432, - } - for i in range(19): - for suffix in lora_dict: - lora_patterns.append({ - "name": f"blocks.{i}.{suffix}", - "dim": lora_dict[suffix] - }) - lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216} - for i in range(38): - for suffix in lora_dict: - lora_patterns.append({ - "name": f"single_blocks.{i}.{suffix}", - "dim": lora_dict[suffix] - }) - return lora_patterns - - def forward(self, base_output, lora_outputs, name): - return self.model_dict[name.replace(".", "___")](base_output, lora_outputs) - - -class LoraDataset(torch.utils.data.Dataset): - def __init__(self, metadata_path, steps_per_epoch=1000): - data_df = pd.read_csv(metadata_path) - self.model_file = data_df["model_file"].tolist() - self.image_file = data_df["image_file"].tolist() - self.text = data_df["text"].tolist() - self.max_resolution = 1920 * 1080 - self.steps_per_epoch = steps_per_epoch - - - def read_image(self, image_file): - image = Image.open(image_file) - width, height = image.size - if width * height > self.max_resolution: - scale = (width * height / self.max_resolution) ** 0.5 - image = image.resize((int(width / scale), int(height / scale))) - width, height = image.size - if width % 16 != 0 or height % 16 != 0: - image = crop_and_resize(image, height // 16 * 16, width // 16 * 16) - image = v2.functional.to_image(image) - image = v2.functional.to_dtype(image, dtype=torch.float32, scale=True) - image = v2.functional.normalize(image, [0.5], [0.5]) - return image - - - def __getitem__(self, index): - data_id = torch.randint(0, len(self.model_file), (1,))[0] - data_id = (data_id + index) % len(self.model_file) # For fixed seed. - data_id_extra = torch.randint(0, len(self.model_file), (1,))[0] - return { - "model_file": self.model_file[data_id], - "model_file_extra": self.model_file[data_id_extra], - "image": self.read_image(self.image_file[data_id]), - "text": self.text[data_id] - } - - - def __len__(self): - return self.steps_per_epoch - - - - -model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") -model_manager.load_models([ - "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", - "models/FLUX/FLUX.1-dev/text_encoder_2", - "models/FLUX/FLUX.1-dev/ae.safetensors", - "models/FLUX/FLUX.1-dev/flux1-dev.safetensors" -]) -pipe = FluxImagePipeline.from_model_manager(model_manager) -pipe.enable_auto_lora() - - -lora_alpahs = [1, 1] -lora_patcher = LoraPatcher().to(dtype=torch.bfloat16, device="cuda") -lora_patcher.load_state_dict(load_state_dict("models/lightning_logs/version_13/checkpoints/epoch=2-step=1500.ckpt")) - -dataset = LoraDataset("data/loras_picked.csv") - -for seed in range(100): - data = dataset[0] - lora_state_dicts = [ - FluxLoRAConverter.align_to_diffsynth_format(load_state_dict(data["model_file"], torch_dtype=torch.bfloat16, device="cuda")), - FluxLoRAConverter.align_to_diffsynth_format(load_state_dict(data["model_file_extra"], torch_dtype=torch.bfloat16, device="cuda")), - ] - lora_alpahs = [1, 1] - for pattern in ["nolora", "lora1", "lora2", "alllora", "loramerger"]: - baseline = pattern - image = pipe( - prompt=data["text"], - lora_state_dicts=lora_state_dicts, - lora_alpahs=lora_alpahs, - lora_patcher=lora_patcher, - seed=seed, - ) - image.save(f"data/lora_outputs/image_{seed}_{pattern}.jpg") \ No newline at end of file diff --git a/scripts/train.py b/scripts/train.py deleted file mode 100644 index 2733851..0000000 --- 
a/scripts/train.py +++ /dev/null @@ -1,207 +0,0 @@ -import torch -import pandas as pd -from PIL import Image -import lightning as pl -from diffsynth import ModelManager, FluxImagePipeline, download_models, load_state_dict -from diffsynth.models.lora import LoRAFromCivitai, FluxLoRAConverter -from diffsynth.data.video import crop_and_resize -from diffsynth.pipelines.flux_image import lets_dance_flux -from torchvision.transforms import v2 - - - -class LoraMerger(torch.nn.Module): - def __init__(self, dim): - super().__init__() - self.weight_base = torch.nn.Parameter(torch.randn((dim,))) - self.weight_lora = torch.nn.Parameter(torch.randn((dim,))) - self.weight_cross = torch.nn.Parameter(torch.randn((dim,))) - self.weight_out = torch.nn.Parameter(torch.ones((dim,))) - self.bias = torch.nn.Parameter(torch.randn((dim,))) - self.activation = torch.nn.Sigmoid() - self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5) - self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5) - - def forward(self, base_output, lora_outputs): - norm_base_output = self.norm_base(base_output) - norm_lora_outputs = self.norm_lora(lora_outputs) - gate = self.activation( - norm_base_output * self.weight_base \ - + norm_lora_outputs * self.weight_lora \ - + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias - ) - output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0) - return output - - - -class LoraPatcher(torch.nn.Module): - def __init__(self, lora_patterns=None): - super().__init__() - if lora_patterns is None: - lora_patterns = self.default_lora_patterns() - model_dict = {} - for lora_pattern in lora_patterns: - name, dim = lora_pattern["name"], lora_pattern["dim"] - model_dict[name.replace(".", "___")] = LoraMerger(dim) - self.model_dict = torch.nn.ModuleDict(model_dict) - - def default_lora_patterns(self): - lora_patterns = [] - lora_dict = { - "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432, - "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432, - } - for i in range(19): - for suffix in lora_dict: - lora_patterns.append({ - "name": f"blocks.{i}.{suffix}", - "dim": lora_dict[suffix] - }) - lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216} - for i in range(38): - for suffix in lora_dict: - lora_patterns.append({ - "name": f"single_blocks.{i}.{suffix}", - "dim": lora_dict[suffix] - }) - return lora_patterns - - def forward(self, base_output, lora_outputs, name): - return self.model_dict[name.replace(".", "___")](base_output, lora_outputs) - - - -class LoraDataset(torch.utils.data.Dataset): - def __init__(self, metadata_path, steps_per_epoch=1000): - data_df = pd.read_csv(metadata_path) - self.model_file = data_df["model_file"].tolist() - self.image_file = data_df["image_file"].tolist() - self.text = data_df["text"].tolist() - self.max_resolution = 1920 * 1080 - self.steps_per_epoch = steps_per_epoch - - - def read_image(self, image_file): - image = Image.open(image_file) - width, height = image.size - if width * height > self.max_resolution: - scale = (width * height / self.max_resolution) ** 0.5 - image = image.resize((int(width / scale), int(height / scale))) - width, height = image.size - if width % 16 != 0 or height % 16 != 0: - image = crop_and_resize(image, height // 16 * 16, width // 16 * 16) - image = v2.functional.to_image(image) - image = v2.functional.to_dtype(image, dtype=torch.float32, scale=True) - image = v2.functional.normalize(image, [0.5], [0.5]) 
- return image - - - def __getitem__(self, index): - data_id = torch.randint(0, len(self.model_file), (1,))[0] - data_id = (data_id + index) % len(self.model_file) # For fixed seed. - data_id_extra = torch.randint(0, len(self.model_file), (1,))[0] - return { - "model_file": self.model_file[data_id], - "model_file_extra": self.model_file[data_id_extra], - "image": self.read_image(self.image_file[data_id]), - "text": self.text[data_id] - } - - - def __len__(self): - return self.steps_per_epoch - - - -class LightningModel(pl.LightningModule): - def __init__( - self, - learning_rate=1e-4, - use_gradient_checkpointing=True, - state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format, - ): - super().__init__() - model_manager = ModelManager(torch_dtype=torch.bfloat16, device=self.device, model_id_list=["FLUX.1-dev"]) - self.pipe = FluxImagePipeline.from_model_manager(model_manager) - self.lora_patcher = LoraPatcher() - self.pipe.enable_auto_lora() - self.pipe.scheduler.set_timesteps(1000, training=True) - self.freeze_parameters() - # Set parameters - self.learning_rate = learning_rate - self.use_gradient_checkpointing = use_gradient_checkpointing - self.state_dict_converter = state_dict_converter - - - def freeze_parameters(self): - # Freeze parameters - self.pipe.requires_grad_(False) - self.pipe.eval() - self.pipe.denoising_model().train() - - - def training_step(self, batch, batch_idx): - # Data - text, image = batch["text"], batch["image"] - lora_state_dicts = [ - self.state_dict_converter(load_state_dict(batch["model_file"][0], torch_dtype=torch.bfloat16, device=self.device)), - self.state_dict_converter(load_state_dict(batch["model_file_extra"][0], torch_dtype=torch.bfloat16, device=self.device)), - ] - lora_alpahs = [1, 1] - - # Prepare input parameters - self.pipe.device = self.device - prompt_emb = self.pipe.encode_prompt(text, positive=True) - if "latents" in batch: - latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device) - else: - latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device)) - noise = torch.randn_like(latents) - timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) - timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device) - extra_input = self.pipe.prepare_extra_input(latents) - noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) - training_target = self.pipe.scheduler.training_target(latents, noise, timestep) - - # Compute loss - noise_pred = lets_dance_flux( - self.pipe.dit, - hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, - lora_state_dicts=lora_state_dicts, lora_alpahs=lora_alpahs, lora_patcher=self.lora_patcher, - use_gradient_checkpointing=self.use_gradient_checkpointing - ) - loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) - loss = loss * self.pipe.scheduler.training_weight(timestep) - - # Record log - self.log("train_loss", loss, prog_bar=True) - return loss - - - def configure_optimizers(self): - trainable_modules = filter(lambda p: p.requires_grad, self.lora_patcher.parameters()) - optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) - return optimizer - - - def on_save_checkpoint(self, checkpoint): - checkpoint.clear() - checkpoint.update(self.lora_patcher.state_dict()) - - -if __name__ == '__main__': - model = LightningModel(learning_rate=1e-4) - dataset = LoraDataset("data/loras.csv", steps_per_epoch=500) - train_loader = torch.utils.data.DataLoader(dataset, 
shuffle=True, batch_size=1, num_workers=1) - trainer = pl.Trainer( - max_epochs=100000, - accelerator="gpu", - devices="auto", - precision="bf16", - strategy="auto", - default_root_dir="./models", - accumulate_grad_batches=1, - callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], - ) - trainer.fit(model=model, train_dataloaders=train_loader)
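
Usage sketch (not part of the diff): the snippet below shows how the modules introduced by this patch are wired together at inference time. It is assembled only from calls that already appear in lora/test_merger.py and lora/utils.py above; the merger checkpoint path, the LoRA file paths, and the prompt are illustrative placeholders rather than files or defaults shipped with the patch.

import torch
from diffsynth import FluxImagePipeline, ModelManager, load_state_dict
from diffsynth.models.lora import FluxLoRAConverter
from lora.merger import LoraPatcher
from lora.utils import load_lora

# Load FLUX.1-dev and enable the AutoLoRA code path so that lora_state_dicts
# and lora_patcher can be supplied per pipeline call.
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
pipe = FluxImagePipeline.from_model_manager(model_manager)
pipe.enable_auto_lora()

# Trained gating network (one LoraMerger per patched layer) that fuses the
# base output with each LoRA branch; the checkpoint path is a placeholder.
lora_patcher = LoraPatcher().to(dtype=torch.bfloat16, device="cuda")
lora_patcher.load_state_dict(load_state_dict("models/lora_merger/epoch-3.safetensors"))

# Two LoRA files (placeholder paths); load_lora rescales them by the stored
# alpha, and the converter maps them to DiffSynth key names.
lora_files = ["data/lora/models/lora_a.safetensors", "data/lora/models/lora_b.safetensors"]
lora_state_dicts = [
    FluxLoRAConverter.align_to_diffsynth_format(load_lora(path, device="cuda"))
    for path in lora_files
]

image = pipe(
    prompt="a portrait photo, studio lighting",
    lora_state_dicts=lora_state_dicts,
    lora_patcher=lora_patcher,
    seed=0,
)
image.save("merged_lora_sample.jpg")

For retrieval, lora/test_retriever.py follows the same pattern but first ranks candidate LoRA files by the similarity between the TextEncoder embedding of the prompt and the LoRAEncoder embeddings of the LoRA weights, then passes the top-k converted state dicts to the pipeline together with the trained LoraPatcher.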