lora merger

This commit is contained in:
Artiprocher
2025-04-21 15:48:25 +08:00
parent 04260801a2
commit 44da204dbd
7 changed files with 516 additions and 30 deletions

View File

@@ -62,25 +62,26 @@ def load_state_dict_from_folder(file_path, torch_dtype=None):
return state_dict
def load_state_dict(file_path, torch_dtype=None):
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
state_dict[k] = state_dict[k].to(device)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
state_dict = torch.load(file_path, map_location=device, weights_only=True)
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):

View File

@@ -401,7 +401,8 @@ class FluxImagePipeline(BasePipeline):
progress_bar_cmd=tqdm,
progress_bar_st=None,
lora_state_dicts=[],
lora_alpahs=[]
lora_alpahs=[],
lora_patcher=None,
):
height, width = self.check_resize_height_width(height, width)
@@ -443,6 +444,7 @@ class FluxImagePipeline(BasePipeline):
hidden_states=latents, timestep=timestep,
lora_state_dicts=lora_state_dicts,
lora_alpahs = lora_alpahs,
lora_patcher=lora_patcher,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
)
noise_pred_posi = self.control_noise_via_local_prompts(
@@ -462,6 +464,7 @@ class FluxImagePipeline(BasePipeline):
hidden_states=latents, timestep=timestep,
lora_state_dicts=lora_state_dicts,
lora_alpahs = lora_alpahs,
lora_patcher=lora_patcher,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
@@ -544,6 +547,7 @@ def lets_dance_flux(
entity_masks=None,
ipadapter_kwargs_list={},
tea_cache: TeaCache = None,
use_gradient_checkpointing=False,
**kwargs
):
@@ -610,6 +614,11 @@ def lets_dance_flux(
prompt_emb = dit.context_embedder(prompt_emb)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
attention_mask = None
def create_custom_forward(module):
def custom_forward(*inputs, **kwargs):
return module(*inputs, **kwargs)
return custom_forward
# TeaCache
if tea_cache is not None:
@@ -622,15 +631,22 @@ def lets_dance_flux(
else:
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None),
**kwargs
)
if use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None), **kwargs,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None),
**kwargs
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
@@ -639,15 +655,22 @@ def lets_dance_flux(
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
num_joint_blocks = len(dit.blocks)
for block_id, block in enumerate(dit.single_blocks):
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
**kwargs
)
if use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None), **kwargs,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
**kwargs
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]

View File

@@ -71,15 +71,16 @@ class AutoWrappedLinear(torch.nn.Linear):
return torch.nn.functional.linear(x, weight, bias)
class AutoLoRALinear(torch.nn.Linear):
def __init__(self, name='', in_features=1, out_features=2, bias = True, device=None, dtype=None):
def __init__(self, name='', in_features=1, out_features=2, bias=True, device=None, dtype=None):
super().__init__(in_features, out_features, bias, device, dtype)
self.name = name
def forward(self, x, lora_state_dicts=[], lora_alpahs=[1.0,1.0], **kwargs):
def forward(self, x, lora_state_dicts=[], lora_alpahs=[1.0,1.0], lora_patcher=None, **kwargs):
out = torch.nn.functional.linear(x, self.weight, self.bias)
lora_a_name = f'{self.name}.lora_A.weight'
lora_b_name = f'{self.name}.lora_B.weight'
lora_a_name = f'{self.name}.lora_A.default.weight'
lora_b_name = f'{self.name}.lora_B.default.weight'
lora_output = []
for i, lora_state_dict in enumerate(lora_state_dicts):
if lora_state_dict is None:
break
@@ -87,7 +88,10 @@ class AutoLoRALinear(torch.nn.Linear):
lora_A = lora_state_dict[lora_a_name].to(dtype=self.weight.dtype,device=self.weight.device)
lora_B = lora_state_dict[lora_b_name].to(dtype=self.weight.dtype,device=self.weight.device)
out_lora = x @ lora_A.T @ lora_B.T
out = out + out_lora * lora_alpahs[i]
lora_output.append(out_lora)
if len(lora_output) > 0:
lora_output = torch.stack(lora_output)
out = lora_patcher(out, lora_output, self.name)
return out
def enable_auto_lora(model:torch.nn.Module, module_map: dict, name_prefix=''):

85
scripts/data_process.py Normal file
View File

@@ -0,0 +1,85 @@
import torch, os, dashscope
import pandas as pd
from tqdm import tqdm
from diffsynth import load_state_dict, hash_state_dict_keys
def search_for_model_file(path, allow_file_extensions=(".safetensors",)):
    """Return the first entry of `path` whose name ends with an allowed extension.

    Entries are examined in the order `os.listdir` yields them, so when several
    entries match, which one is returned depends on that order.

    Args:
        path: Directory to scan (not recursive).
        allow_file_extensions: Iterable of accepted file-name suffixes.

    Returns:
        `os.path.join(path, name)` for the first matching entry, or None when
        no entry matches.
    """
    suffixes = tuple(allow_file_extensions)
    candidates = (
        os.path.join(path, entry)
        for entry in os.listdir(path)
        if entry.endswith(suffixes)
    )
    return next(candidates, None)
def search_for_cover_images(path, allow_file_extensions=(".png", ".jpg", ".jpeg")):
    """Collect every entry of `path` whose name ends with an allowed extension.

    The match is case-sensitive and the scan is not recursive. Results keep
    the order in which `os.listdir` returns the entries.

    Args:
        path: Directory to scan.
        allow_file_extensions: Iterable of accepted file-name suffixes.

    Returns:
        A list of `os.path.join(path, name)` strings; empty when nothing matches.
    """
    suffixes = tuple(allow_file_extensions)
    return [
        os.path.join(path, entry)
        for entry in os.listdir(path)
        if entry.endswith(suffixes)
    ]
def search_for_lora_data(path):
    """Locate a usable LoRA sample (model file + cover images) inside `path`.

    A folder qualifies only when all of the following hold:
      * it directly contains a model file found by `search_for_model_file`,
      * it contains an entry named "_cover_images_" holding at least one image,
      * the hash of the model's state-dict keys (shapes ignored) equals the
        expected value below.

    Args:
        path: Folder of a single LoRA.

    Returns:
        `(model_file, image_files)` when the folder qualifies, otherwise None.
    """
    model_file = search_for_model_file(path)

    cover_dir_name = "_cover_images_"
    if cover_dir_name not in os.listdir(path):
        return None
    image_files = search_for_cover_images(os.path.join(path, cover_dir_name))
    if model_file is None or not image_files:
        return None

    # NOTE(review): the expected hash presumably identifies one specific LoRA
    # key layout accepted by the rest of the pipeline -- confirm which one.
    key_hash = hash_state_dict_keys(load_state_dict(model_file), with_shape=False)
    if key_hash == "52544ae3076666228978b738fbb8b086":
        return model_file, image_files
    return None
def image_to_text(images=[], prompt="", system_prompt=None):
    """Ask the "qwen-vl-max-latest" model to produce text for the given image(s).

    Args:
        images: A single image reference or a list of them. A non-list value
            is wrapped into a one-element list. The default list is never
            mutated here.
        prompt: Text placed first in the user message.
        system_prompt: Optional system message; omitted when None.

    Returns:
        The `text` field of the first content item of the first choice in the
        service response.
    """
    dashscope.api_key = "xxxxx" # TODO

    image_list = images if isinstance(images, list) else [images]
    user_content = [{"text": prompt}]
    user_content.extend({"image": image} for image in image_list)

    messages = []
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_content})

    response = dashscope.MultiModalConversation.call(model="qwen-vl-max-latest", messages=messages)
    first_choice = response["output"]["choices"][0]
    return first_choice["message"]["content"][0]["text"]
# Captioning instruction passed as `prompt` to image_to_text for every cover
# image; surrounding blank lines of the literal are removed by .strip().
qwen_i2t_prompt = '''
You are a professional image captioner.
Generate a caption according to the image so that another image generation model can generate the image via the caption. Just return the string description, do not return anything else.
'''.strip()
def data_to_csv(model_file_list, image_file_list, text_list, save_path):
    """Write the three parallel lists to `save_path` as a CSV table.

    The file has the columns "model_file", "image_file" and "text" (in that
    order), no index column, and is encoded as UTF-8 with a byte-order mark.

    Args:
        model_file_list: Model file path per row.
        image_file_list: Image file path per row.
        text_list: Caption per row.
        save_path: Destination CSV path.
    """
    table = pd.DataFrame({
        "model_file": model_file_list,
        "image_file": image_file_list,
        "text": text_list,
    })
    table.to_csv(save_path, index=False, encoding="utf-8-sig")
# Build the training metadata table: for every LoRA folder under `base_path`
# that passes search_for_lora_data, caption each cover image and record one
# (model file, image file, caption) row.
base_path = "/data/zhiwen/LoRA-Fusion/models/FLUXLoRA"
model_file_list = []
image_file_list = []
text_list = []

# Create the output directory up front so that a missing folder is detected
# before the (slow) captioning loop rather than after it.
os.makedirs("data", exist_ok=True)

for lora_name in tqdm(os.listdir(base_path)):
    lora_folder_path = os.path.join(base_path, lora_name)
    if not os.path.isdir(lora_folder_path):
        continue
    data = search_for_lora_data(lora_folder_path)
    if data is None:
        continue
    model_file, image_files = data
    for image_file in image_files:
        try:
            text = image_to_text(image_file, prompt=qwen_i2t_prompt)
        except Exception as error:
            # Captioning stays best-effort (a failed image is skipped), but the
            # failure is reported, and KeyboardInterrupt/SystemExit are no
            # longer swallowed as they were by the former bare `except:`.
            tqdm.write(f"Skipping {image_file}: {error!r}")
            continue
        model_file_list.append(model_file)
        image_file_list.append(image_file)
        text_list.append(text)

data_to_csv(model_file_list, image_file_list, text_list, "data/loras.csv")

166
scripts/test.py Normal file
View File

@@ -0,0 +1,166 @@
import torch, shutil, os
from diffsynth import ModelManager, FluxImagePipeline, download_models, load_state_dict
from diffsynth.models.lora import LoRAFromCivitai, FluxLoRAConverter
import pandas as pd
import torch
import pandas as pd
from PIL import Image
import lightning as pl
from diffsynth import ModelManager, FluxImagePipeline, download_models, load_state_dict
from diffsynth.models.lora import LoRAFromCivitai, FluxLoRAConverter
from diffsynth.data.video import crop_and_resize
from diffsynth.pipelines.flux_image import lets_dance_flux
from torchvision.transforms import v2
# Module-level mode switch read by LoraMerger.forward. "nolora", "lora1",
# "lora2" and "alllora" select fixed combinations of the LoRA outputs; any
# other value (including this default) selects the learned gated merge.
baseline = "trained"
class LoraMerger(torch.nn.Module):
    """Combine a layer's base output with stacked LoRA outputs.

    How the outputs are combined is selected by the module-level variable
    `baseline`, which is read on every forward call:

      * "nolora"  -> base output only
      * "lora1"   -> base output + first LoRA output
      * "lora2"   -> base output + second LoRA output
      * "alllora" -> base output + sum of all LoRA outputs
      * anything else -> learned, element-wise gated sum of the LoRA outputs
    """

    def __init__(self, dim):
        """Create the per-feature gate parameters for vectors of size `dim`."""
        super().__init__()

        def random_vector():
            return torch.nn.Parameter(torch.randn((dim,)))

        # Creation order is kept stable: it fixes both the order in which
        # random numbers are drawn and the key order of the state dict.
        self.weight_base = random_vector()
        self.weight_lora = random_vector()
        self.weight_cross = random_vector()
        self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
        self.bias = random_vector()
        self.activation = torch.nn.Sigmoid()
        self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
        self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)

    def _gated_merge(self, base_output, lora_outputs):
        """Add the LoRA outputs to the base output, each scaled by a sigmoid gate."""
        normed_base = self.norm_base(base_output)
        normed_lora = self.norm_lora(lora_outputs)
        gate_logits = (
            normed_base * self.weight_base
            + normed_lora * self.weight_lora
            + normed_base * normed_lora * self.weight_cross
            + self.bias
        )
        gate = self.activation(gate_logits)
        # dim 0 of `lora_outputs` is reduced by the sum below.
        return base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)

    def forward(self, base_output, lora_outputs):
        """Return the merged output for the mode currently stored in `baseline`.

        Args:
            base_output: Output of the wrapped layer without LoRA.
            lora_outputs: LoRA outputs; indexed and summed along dim 0.
        """
        mode = baseline
        if mode == "nolora":
            return base_output
        if mode == "lora1":
            return base_output + lora_outputs[0]
        if mode == "lora2":
            return base_output + lora_outputs[1]
        if mode == "alllora":
            return base_output + lora_outputs.sum(dim=0)
        return self._gated_merge(base_output, lora_outputs)
class LoraPatcher(torch.nn.Module):
    """Container holding one LoraMerger per named layer.

    Layer names contain dots, which `torch.nn.ModuleDict` does not accept in
    keys, so every "." is stored as "___".
    """

    def __init__(self, lora_patterns=None):
        """Build one merger per pattern.

        Args:
            lora_patterns: Iterable of dicts with keys "name" (layer name) and
                "dim" (size handed to LoraMerger). When None, the list from
                `default_lora_patterns` is used.
        """
        super().__init__()
        patterns = self.default_lora_patterns() if lora_patterns is None else lora_patterns
        self.model_dict = torch.nn.ModuleDict({
            pattern["name"].replace(".", "___"): LoraMerger(pattern["dim"])
            for pattern in patterns
        })

    def default_lora_patterns(self):
        """Return the built-in pattern list: 19 `blocks.*` and 38 `single_blocks.*` groups."""
        # NOTE(review): the sizes are presumably the output feature sizes of
        # the corresponding linear layers -- confirm against the model.
        joint_block_dims = {
            "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
            "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
        }
        single_block_dims = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}

        patterns = [
            {"name": f"blocks.{block_id}.{suffix}", "dim": dim}
            for block_id in range(19)
            for suffix, dim in joint_block_dims.items()
        ]
        patterns.extend(
            {"name": f"single_blocks.{block_id}.{suffix}", "dim": dim}
            for block_id in range(38)
            for suffix, dim in single_block_dims.items()
        )
        return patterns

    def forward(self, base_output, lora_outputs, name):
        """Merge `lora_outputs` into `base_output` with the merger registered for `name`.

        Raises KeyError when no merger exists for `name`.
        """
        merger = self.model_dict[name.replace(".", "___")]
        return merger(base_output, lora_outputs)
class LoraDataset(torch.utils.data.Dataset):
    """Randomly sampled (LoRA file, extra LoRA file, image, caption) items.

    The metadata CSV must provide the columns "model_file", "image_file" and
    "text". The reported length is `steps_per_epoch`, not the number of CSV
    rows; each item is drawn at random.
    """

    def __init__(self, metadata_path, steps_per_epoch=1000):
        """Read the metadata CSV and store its columns as Python lists."""
        table = pd.read_csv(metadata_path)
        self.model_file, self.image_file, self.text = (
            table[column].tolist() for column in ("model_file", "image_file", "text")
        )
        # Images with more pixels than this are scaled down in read_image.
        self.max_resolution = 1920 * 1080
        self.steps_per_epoch = steps_per_epoch

    def read_image(self, image_file):
        """Load an image and return it as a normalized float32 tensor.

        Images exceeding `max_resolution` pixels are scaled down keeping the
        aspect ratio; afterwards both sides are brought to a multiple of 16
        with `crop_and_resize`. Values are scaled and then normalized with
        mean 0.5 and std 0.5.
        """
        picture = Image.open(image_file)
        width, height = picture.size
        pixel_count = width * height
        if pixel_count > self.max_resolution:
            scale = (pixel_count / self.max_resolution) ** 0.5
            picture = picture.resize((int(width / scale), int(height / scale)))
            width, height = picture.size
        if width % 16 != 0 or height % 16 != 0:
            picture = crop_and_resize(picture, height // 16 * 16, width // 16 * 16)
        tensor = v2.functional.to_image(picture)
        tensor = v2.functional.to_dtype(tensor, dtype=torch.float32, scale=True)
        return v2.functional.normalize(tensor, [0.5], [0.5])

    def __getitem__(self, index):
        """Return a random sample paired with a second, independently drawn LoRA file."""
        row_count = len(self.model_file)
        data_id = torch.randint(0, row_count, (1,))[0]
        # For fixed seed. (Presumably the offset keeps items distinct when the
        # RNG state repeats -- TODO confirm.)
        data_id = (data_id + index) % row_count
        data_id_extra = torch.randint(0, row_count, (1,))[0]
        return {
            "model_file": self.model_file[data_id],
            "model_file_extra": self.model_file[data_id_extra],
            "image": self.read_image(self.image_file[data_id]),
            "text": self.text[data_id]
        }

    def __len__(self):
        """Number of items served per epoch."""
        return self.steps_per_epoch
# Evaluation script: for each seed, draw one dataset item (prompt plus two
# LoRA files) and render it once per merge mode so the results can be compared.
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
    "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    "models/FLUX/FLUX.1-dev/text_encoder_2",
    "models/FLUX/FLUX.1-dev/ae.safetensors",
    "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
pipe = FluxImagePipeline.from_model_manager(model_manager)
pipe.enable_auto_lora()

lora_alpahs = [1, 1]
lora_patcher = LoraPatcher().to(dtype=torch.bfloat16, device="cuda")
lora_patcher.load_state_dict(load_state_dict("models/lightning_logs/version_13/checkpoints/epoch=2-step=1500.ckpt"))

dataset = LoraDataset("data/loras_picked.csv")

# Image.save does not create missing folders, so make sure the target exists
# before any image is rendered.
output_dir = "data/lora_outputs"
os.makedirs(output_dir, exist_ok=True)

for seed in range(100):
    data = dataset[0]
    lora_state_dicts = [
        FluxLoRAConverter.align_to_diffsynth_format(load_state_dict(data["model_file"], torch_dtype=torch.bfloat16, device="cuda")),
        FluxLoRAConverter.align_to_diffsynth_format(load_state_dict(data["model_file_extra"], torch_dtype=torch.bfloat16, device="cuda")),
    ]
    for pattern in ["nolora", "lora1", "lora2", "alllora", "loramerger"]:
        # LoraMerger.forward reads this module-level variable to pick the merge
        # mode; "loramerger" matches none of the fixed modes and therefore
        # selects the learned gated merge.
        baseline = pattern
        image = pipe(
            prompt=data["text"],
            lora_state_dicts=lora_state_dicts,
            lora_alpahs=lora_alpahs,
            lora_patcher=lora_patcher,
            seed=seed,
        )
        image.save(f"{output_dir}/image_{seed}_{pattern}.jpg")

207
scripts/train.py Normal file
View File

@@ -0,0 +1,207 @@
import torch
import pandas as pd
from PIL import Image
import lightning as pl
from diffsynth import ModelManager, FluxImagePipeline, download_models, load_state_dict
from diffsynth.models.lora import LoRAFromCivitai, FluxLoRAConverter
from diffsynth.data.video import crop_and_resize
from diffsynth.pipelines.flux_image import lets_dance_flux
from torchvision.transforms import v2
class LoraMerger(torch.nn.Module):
    """Learned, element-wise gate that adds stacked LoRA outputs to a base output.

    For every LoRA output a sigmoid gate is computed from the layer-normalized
    base output, the layer-normalized LoRA output and their product; the gated
    LoRA outputs are scaled by `weight_out`, summed over dim 0 and added to
    the base output.
    """

    def __init__(self, dim):
        """Create the per-feature gate parameters for vectors of size `dim`."""
        super().__init__()

        def random_vector():
            return torch.nn.Parameter(torch.randn((dim,)))

        # Creation order is kept stable: it fixes both the order in which
        # random numbers are drawn and the key order of the state dict.
        self.weight_base = random_vector()
        self.weight_lora = random_vector()
        self.weight_cross = random_vector()
        self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
        self.bias = random_vector()
        self.activation = torch.nn.Sigmoid()
        self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
        self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)

    def forward(self, base_output, lora_outputs):
        """Return `base_output` plus the gated sum of `lora_outputs`.

        Args:
            base_output: Output of the wrapped layer without LoRA.
            lora_outputs: LoRA outputs; reduced along dim 0.
        """
        normed_base = self.norm_base(base_output)
        normed_lora = self.norm_lora(lora_outputs)
        gate_logits = (
            normed_base * self.weight_base
            + normed_lora * self.weight_lora
            + normed_base * normed_lora * self.weight_cross
            + self.bias
        )
        gate = self.activation(gate_logits)
        gated_sum = (self.weight_out * gate * lora_outputs).sum(dim=0)
        return base_output + gated_sum
class LoraPatcher(torch.nn.Module):
    """Container holding one LoraMerger per named layer.

    Layer names contain dots, which `torch.nn.ModuleDict` does not accept in
    keys, so every "." is stored as "___".
    """

    def __init__(self, lora_patterns=None):
        """Build one merger per pattern.

        Args:
            lora_patterns: Iterable of dicts with keys "name" (layer name) and
                "dim" (size handed to LoraMerger). When None, the list from
                `default_lora_patterns` is used.
        """
        super().__init__()
        patterns = self.default_lora_patterns() if lora_patterns is None else lora_patterns
        self.model_dict = torch.nn.ModuleDict({
            pattern["name"].replace(".", "___"): LoraMerger(pattern["dim"])
            for pattern in patterns
        })

    def default_lora_patterns(self):
        """Return the built-in pattern list: 19 `blocks.*` and 38 `single_blocks.*` groups."""
        # NOTE(review): the sizes are presumably the output feature sizes of
        # the corresponding linear layers -- confirm against the model.
        joint_block_dims = {
            "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
            "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
        }
        single_block_dims = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}

        patterns = [
            {"name": f"blocks.{block_id}.{suffix}", "dim": dim}
            for block_id in range(19)
            for suffix, dim in joint_block_dims.items()
        ]
        patterns.extend(
            {"name": f"single_blocks.{block_id}.{suffix}", "dim": dim}
            for block_id in range(38)
            for suffix, dim in single_block_dims.items()
        )
        return patterns

    def forward(self, base_output, lora_outputs, name):
        """Merge `lora_outputs` into `base_output` with the merger registered for `name`.

        Raises KeyError when no merger exists for `name`.
        """
        merger = self.model_dict[name.replace(".", "___")]
        return merger(base_output, lora_outputs)
class LoraDataset(torch.utils.data.Dataset):
    """Randomly sampled (LoRA file, extra LoRA file, image, caption) items.

    The metadata CSV must provide the columns "model_file", "image_file" and
    "text". The reported length is `steps_per_epoch`, not the number of CSV
    rows; each item is drawn at random.
    """

    def __init__(self, metadata_path, steps_per_epoch=1000):
        """Read the metadata CSV and store its columns as Python lists."""
        table = pd.read_csv(metadata_path)
        self.model_file, self.image_file, self.text = (
            table[column].tolist() for column in ("model_file", "image_file", "text")
        )
        # Images with more pixels than this are scaled down in read_image.
        self.max_resolution = 1920 * 1080
        self.steps_per_epoch = steps_per_epoch

    def read_image(self, image_file):
        """Load an image and return it as a normalized float32 tensor.

        Images exceeding `max_resolution` pixels are scaled down keeping the
        aspect ratio; afterwards both sides are brought to a multiple of 16
        with `crop_and_resize`. Values are scaled and then normalized with
        mean 0.5 and std 0.5.
        """
        picture = Image.open(image_file)
        width, height = picture.size
        pixel_count = width * height
        if pixel_count > self.max_resolution:
            scale = (pixel_count / self.max_resolution) ** 0.5
            picture = picture.resize((int(width / scale), int(height / scale)))
            width, height = picture.size
        if width % 16 != 0 or height % 16 != 0:
            picture = crop_and_resize(picture, height // 16 * 16, width // 16 * 16)
        tensor = v2.functional.to_image(picture)
        tensor = v2.functional.to_dtype(tensor, dtype=torch.float32, scale=True)
        return v2.functional.normalize(tensor, [0.5], [0.5])

    def __getitem__(self, index):
        """Return a random sample paired with a second, independently drawn LoRA file."""
        row_count = len(self.model_file)
        data_id = torch.randint(0, row_count, (1,))[0]
        # For fixed seed. (Presumably the offset keeps items distinct when the
        # RNG state repeats -- TODO confirm.)
        data_id = (data_id + index) % row_count
        data_id_extra = torch.randint(0, row_count, (1,))[0]
        return {
            "model_file": self.model_file[data_id],
            "model_file_extra": self.model_file[data_id_extra],
            "image": self.read_image(self.image_file[data_id]),
            "text": self.text[data_id]
        }

    def __len__(self):
        """Number of items served per epoch."""
        return self.steps_per_epoch
class LightningModel(pl.LightningModule):
    """Lightning module that trains a LoraPatcher on top of a frozen FLUX pipeline.

    Each training step loads two LoRA state dicts named by the batch, runs the
    DiT with both LoRAs routed through `self.lora_patcher`, and minimizes the
    weighted MSE between the prediction and the scheduler's training target.
    Only the parameters of `self.lora_patcher` are handed to the optimizer.
    """

    def __init__(
        self,
        learning_rate=1e-4,
        use_gradient_checkpointing=True,
        state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format,
    ):
        """Load FLUX.1-dev, enable automatic LoRA injection and freeze the pipeline.

        Args:
            learning_rate: Learning rate for AdamW.
            use_gradient_checkpointing: Forwarded to `lets_dance_flux`.
            state_dict_converter: Callable applied to every loaded LoRA state dict.
        """
        super().__init__()
        manager = ModelManager(torch_dtype=torch.bfloat16, device=self.device, model_id_list=["FLUX.1-dev"])
        self.pipe = FluxImagePipeline.from_model_manager(manager)
        self.lora_patcher = LoraPatcher()
        self.pipe.enable_auto_lora()
        self.pipe.scheduler.set_timesteps(1000, training=True)
        self.freeze_parameters()

        # Hyper-parameters and options used by the other hooks.
        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.state_dict_converter = state_dict_converter

    def freeze_parameters(self):
        """Disable gradients for the whole pipeline and set its train/eval modes.

        The pipeline is put in eval mode, then its denoising model is switched
        back to train mode. `self.lora_patcher` is not touched here.
        """
        pipeline = self.pipe
        pipeline.requires_grad_(False)
        pipeline.eval()
        pipeline.denoising_model().train()

    def training_step(self, batch, batch_idx):
        """Compute the denoising loss for one batch.

        Only the first entry of `batch["model_file"]` and
        `batch["model_file_extra"]` is used. When the batch provides
        "latents" they are used directly; otherwise the image is encoded with
        the pipeline's VAE encoder.
        """
        # Data
        text, image = batch["text"], batch["image"]
        lora_state_dicts = [
            self.state_dict_converter(load_state_dict(batch[key][0], torch_dtype=torch.bfloat16, device=self.device))
            for key in ("model_file", "model_file_extra")
        ]
        lora_alpahs = [1, 1]

        # Prepare input parameters
        self.pipe.device = self.device
        prompt_emb = self.pipe.encode_prompt(text, positive=True)
        if "latents" in batch:
            latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
        else:
            latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
        noise = torch.randn_like(latents)
        scheduler = self.pipe.scheduler
        timestep_id = torch.randint(0, scheduler.num_train_timesteps, (1,))
        timestep = scheduler.timesteps[timestep_id].to(self.device)
        extra_input = self.pipe.prepare_extra_input(latents)
        noisy_latents = scheduler.add_noise(latents, noise, timestep)
        training_target = scheduler.training_target(latents, noise, timestep)

        # Compute loss
        noise_pred = lets_dance_flux(
            self.pipe.dit,
            hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
            lora_state_dicts=lora_state_dicts, lora_alpahs=lora_alpahs, lora_patcher=self.lora_patcher,
            use_gradient_checkpointing=self.use_gradient_checkpointing
        )
        mse = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = mse * scheduler.training_weight(timestep)

        # Record log
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over the trainable parameters of the LoRA patcher."""
        trainable_parameters = (
            parameter for parameter in self.lora_patcher.parameters() if parameter.requires_grad
        )
        return torch.optim.AdamW(trainable_parameters, lr=self.learning_rate)

    def on_save_checkpoint(self, checkpoint):
        """Replace the checkpoint content with the LoRA patcher's state dict only.

        Everything Lightning placed in `checkpoint` is discarded, so the saved
        file is a flat state dict of `self.lora_patcher`.
        """
        patcher_state = self.lora_patcher.state_dict()
        checkpoint.clear()
        checkpoint.update(patcher_state)
def main():
    """Train the LoRA patcher on data/loras.csv, saving a checkpoint after every epoch."""
    lightning_model = LightningModel(learning_rate=1e-4)
    loader = torch.utils.data.DataLoader(
        LoraDataset("data/loras.csv", steps_per_epoch=500),
        shuffle=True,
        batch_size=1,
        num_workers=1,
    )
    # save_top_k=-1 keeps every checkpoint instead of only the best ones.
    checkpoint_callback = pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)
    trainer = pl.Trainer(
        max_epochs=100000,
        accelerator="gpu",
        devices="auto",
        precision="bf16",
        strategy="auto",
        default_root_dir="./models",
        accumulate_grad_batches=1,
        callbacks=[checkpoint_callback],
    )
    trainer.fit(model=lightning_model, train_dataloaders=loader)


if __name__ == '__main__':
    main()