mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
refine training
This commit is contained in:
@@ -29,6 +29,8 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
|
||||
if lora_alpha is None:
|
||||
lora_alpha = lora_rank
|
||||
if isinstance(target_modules, list) and len(target_modules) == 1:
|
||||
target_modules = target_modules[0]
|
||||
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
|
||||
model = inject_adapter_in_model(lora_config, model)
|
||||
if upcast_dtype is not None:
|
||||
|
||||
Reference in New Issue
Block a user