diff --git a/diffsynth/diffusion/logger.py b/diffsynth/diffusion/logger.py
index ff51e2c..6d2792f 100644
--- a/diffsynth/diffusion/logger.py
+++ b/diffsynth/diffusion/logger.py
@@ -10,7 +10,7 @@ class ModelLogger:
         self.num_steps = 0
 
-    def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
+    def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):
         self.num_steps += 1
         if save_steps is not None and self.num_steps % save_steps == 0:
             self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
 
diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py
index 63cd856..f6e2263 100644
--- a/diffsynth/diffusion/runner.py
+++ b/diffsynth/diffusion/runner.py
@@ -40,7 +40,7 @@ def launch_training_task(
             loss = model(data)
             accelerator.backward(loss)
             optimizer.step()
-            model_logger.on_step_end(accelerator, model, save_steps)
+            model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
             scheduler.step()
         if save_steps is None:
             model_logger.on_epoch_end(accelerator, model, epoch_id)
diff --git a/diffsynth/diffusion/training_module.py b/diffsynth/diffusion/training_module.py
index b658866..cc2e79d 100644
--- a/diffsynth/diffusion/training_module.py
+++ b/diffsynth/diffusion/training_module.py
@@ -151,6 +151,49 @@ class DiffusionTrainingModule(torch.nn.Module):
             origin_file_pattern = model_id_with_origin_path[split_id + 1:]
         return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
 
+
+    def auto_detect_lora_target_modules(
+        self,
+        model: torch.nn.Module,
+        search_for_linear=False,
+        linear_detector=lambda x: min(x.weight.shape) >= 512,
+        block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
+        name_prefix="",
+    ):
+        """Recursively collect fully qualified names of Linear layers to patch with LoRA.
+
+        Walks the module tree; once a child matches ``block_list_detector``
+        (a ``ModuleList`` with more than one entry, i.e. a repeated block stack),
+        every ``nn.Linear`` inside it that passes ``linear_detector`` is collected.
+        """
+        lora_target_modules = []
+        if search_for_linear:
+            for name, module in model.named_modules():
+                module_name = name_prefix + ("." if name_prefix else "") + name
+                if isinstance(module, torch.nn.Linear) and linear_detector(module):
+                    lora_target_modules.append(module_name)
+        else:
+            for name, module in model.named_children():
+                module_name = name_prefix + ("." if name_prefix else "") + name
+                lora_target_modules += self.auto_detect_lora_target_modules(
+                    module,
+                    search_for_linear=block_list_detector(module),
+                    linear_detector=linear_detector,
+                    block_list_detector=block_list_detector,
+                    name_prefix=module_name,
+                )
+        return lora_target_modules
+
+    def parse_lora_target_modules(self, model, lora_target_modules):
+        """Resolve a comma-separated target-module string; auto-detect when empty."""
+        if not lora_target_modules:  # "" (or None) means: search automatically
+            print("No LoRA target modules specified. The framework will automatically search for them.")
+            lora_target_modules = self.auto_detect_lora_target_modules(model)
+            print(f"LoRA will be patched at {lora_target_modules}.")
+        else:
+            lora_target_modules = lora_target_modules.split(",")
+        return lora_target_modules
+
     def switch_pipe_to_training_mode(
         self,
         pipe,
@@ -180,7 +223,7 @@ class DiffusionTrainingModule(torch.nn.Module):
             return
         model = self.add_lora_to_model(
             getattr(pipe, lora_base_model),
-            target_modules=lora_target_modules.split(","),
+            target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
             lora_rank=lora_rank,
             upcast_dtype=pipe.torch_dtype,
         )