mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge pull request #1222 from modelscope/trainer-update
support auto detect of LoRA target modules
This commit is contained in:
@@ -10,7 +10,7 @@ class ModelLogger:
|
|||||||
self.num_steps = 0
|
self.num_steps = 0
|
||||||
|
|
||||||
|
|
||||||
def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):
    """Advance the step counter and periodically checkpoint the model.

    Args:
        accelerator: forwarded to `self.save_model` for distributed-aware saving.
        model: the model being trained; forwarded to `self.save_model`.
        save_steps: checkpoint interval in optimizer steps; `None` disables
            periodic saving entirely.
        **kwargs: extra per-step data (e.g. `loss=`) accepted and ignored here
            so callers can pass richer context without breaking this hook.
    """
    self.num_steps = self.num_steps + 1
    # Save only on exact multiples of the interval; save_steps=None skips saving.
    checkpoint_due = save_steps is not None and self.num_steps % save_steps == 0
    if checkpoint_due:
        self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ def launch_training_task(
|
|||||||
loss = model(data)
|
loss = model(data)
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
model_logger.on_step_end(accelerator, model, save_steps)
|
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
if save_steps is None:
|
if save_steps is None:
|
||||||
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||||
|
|||||||
@@ -150,7 +150,44 @@ class DiffusionTrainingModule(torch.nn.Module):
|
|||||||
origin_file_pattern = model_id_with_origin_path[split_id + 1:]
|
origin_file_pattern = model_id_with_origin_path[split_id + 1:]
|
||||||
return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
|
return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
|
||||||
|
|
||||||
|
|
||||||
|
def auto_detect_lora_target_modules(
    self,
    model: torch.nn.Module,
    search_for_linear=False,
    linear_detector=lambda x: min(x.weight.shape) >= 512,
    block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
    name_prefix="",
):
    """Recursively collect fully-qualified names of Linear layers to patch with LoRA.

    The walk descends through `named_children` until it reaches a subtree
    flagged by `block_list_detector` (by default, a `ModuleList` with more
    than one entry, i.e. a stack of transformer blocks). Inside such a
    subtree every `torch.nn.Linear` accepted by `linear_detector` (by
    default, both weight dimensions >= 512) is recorded by its dotted
    module path relative to the original `model`.

    Args:
        model: module (or submodule, during recursion) to scan.
        search_for_linear: when True, scan this whole subtree for Linears
            instead of recursing further.
        linear_detector: predicate selecting which Linear layers qualify.
        block_list_detector: predicate deciding where to switch from
            recursion to a flat Linear scan.
        name_prefix: dotted path of `model` relative to the original root;
            "" at the top-level call.

    Returns:
        List of dotted module names suitable as LoRA target modules.
    """
    if search_for_linear:
        # Flat scan of this subtree: qualify each accepted Linear by its path.
        return [
            name_prefix + ("." if name_prefix else "") + name
            for name, module in model.named_modules()
            if isinstance(module, torch.nn.Linear) and linear_detector(module)
        ]
    collected = []
    for child_name, child in model.named_children():
        qualified = name_prefix + ("." if name_prefix else "") + child_name
        collected.extend(
            self.auto_detect_lora_target_modules(
                child,
                # Flip into flat-scan mode once a block list is reached.
                search_for_linear=block_list_detector(child),
                linear_detector=linear_detector,
                block_list_detector=block_list_detector,
                name_prefix=qualified,
            )
        )
    return collected
|
||||||
|
|
||||||
|
|
||||||
|
def parse_lora_target_modules(self, model, lora_target_modules):
    """Resolve the LoRA target-module setting into a list of module names.

    Args:
        model: the base model, scanned only when auto-detection is needed.
        lora_target_modules: comma-separated module names, or "" to let the
            framework discover suitable modules automatically.

    Returns:
        List of dotted module names to patch with LoRA.
    """
    # Non-empty string: the user picked the modules explicitly.
    if lora_target_modules != "":
        return lora_target_modules.split(",")
    print("No LoRA target modules specified. The framework will automatically search for them.")
    detected = self.auto_detect_lora_target_modules(model)
    print(f"LoRA will be patched at {detected}.")
    return detected
|
||||||
|
|
||||||
|
|
||||||
def switch_pipe_to_training_mode(
|
def switch_pipe_to_training_mode(
|
||||||
self,
|
self,
|
||||||
pipe,
|
pipe,
|
||||||
@@ -180,7 +217,7 @@ class DiffusionTrainingModule(torch.nn.Module):
|
|||||||
return
|
return
|
||||||
model = self.add_lora_to_model(
|
model = self.add_lora_to_model(
|
||||||
getattr(pipe, lora_base_model),
|
getattr(pipe, lora_base_model),
|
||||||
target_modules=lora_target_modules.split(","),
|
target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
|
||||||
lora_rank=lora_rank,
|
lora_rank=lora_rank,
|
||||||
upcast_dtype=pipe.torch_dtype,
|
upcast_dtype=pipe.torch_dtype,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user