mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support auto detact lora target modules
This commit is contained in:
@@ -10,7 +10,7 @@ class ModelLogger:
|
||||
self.num_steps = 0
|
||||
|
||||
|
||||
def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
|
||||
def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):
|
||||
self.num_steps += 1
|
||||
if save_steps is not None and self.num_steps % save_steps == 0:
|
||||
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||
|
||||
@@ -40,7 +40,7 @@ def launch_training_task(
|
||||
loss = model(data)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
model_logger.on_step_end(accelerator, model, save_steps)
|
||||
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
|
||||
scheduler.step()
|
||||
if save_steps is None:
|
||||
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||
|
||||
@@ -150,7 +150,44 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
origin_file_pattern = model_id_with_origin_path[split_id + 1:]
|
||||
return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
|
||||
|
||||
|
||||
def auto_detect_lora_target_modules(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
search_for_linear=False,
|
||||
linear_detector=lambda x: min(x.weight.shape) >= 512,
|
||||
block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
|
||||
name_prefix="",
|
||||
):
|
||||
lora_target_modules = []
|
||||
if search_for_linear:
|
||||
for name, module in model.named_modules():
|
||||
module_name = name_prefix + ["", "."][name_prefix != ""] + name
|
||||
if isinstance(module, torch.nn.Linear) and linear_detector(module):
|
||||
lora_target_modules.append(module_name)
|
||||
else:
|
||||
for name, module in model.named_children():
|
||||
module_name = name_prefix + ["", "."][name_prefix != ""] + name
|
||||
lora_target_modules += self.auto_detect_lora_target_modules(
|
||||
module,
|
||||
search_for_linear=block_list_detector(module),
|
||||
linear_detector=linear_detector,
|
||||
block_list_detector=block_list_detector,
|
||||
name_prefix=module_name,
|
||||
)
|
||||
return lora_target_modules
|
||||
|
||||
|
||||
def parse_lora_target_modules(self, model, lora_target_modules):
|
||||
if lora_target_modules == "":
|
||||
print("No LoRA target modules specified. The framework will automatically search for them.")
|
||||
lora_target_modules = self.auto_detect_lora_target_modules(model)
|
||||
print(f"LoRA will be patched at {lora_target_modules}.")
|
||||
else:
|
||||
lora_target_modules = lora_target_modules.split(",")
|
||||
return lora_target_modules
|
||||
|
||||
|
||||
def switch_pipe_to_training_mode(
|
||||
self,
|
||||
pipe,
|
||||
@@ -180,7 +217,7 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
return
|
||||
model = self.add_lora_to_model(
|
||||
getattr(pipe, lora_base_model),
|
||||
target_modules=lora_target_modules.split(","),
|
||||
target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
|
||||
lora_rank=lora_rank,
|
||||
upcast_dtype=pipe.torch_dtype,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user