mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
refine training
This commit is contained in:
@@ -89,7 +89,7 @@ class FlowMatchScheduler():
|
||||
return float(mu)
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):
|
||||
def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=1024//16*1024//16):
|
||||
sigma_min = 1 / num_inference_steps
|
||||
sigma_max = 1.0
|
||||
num_train_timesteps = 1000
|
||||
|
||||
@@ -29,6 +29,8 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
|
||||
if lora_alpha is None:
|
||||
lora_alpha = lora_rank
|
||||
if isinstance(target_modules, list) and len(target_modules) == 1:
|
||||
target_modules = target_modules[0]
|
||||
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
|
||||
model = inject_adapter_in_model(lora_config, model)
|
||||
if upcast_dtype is not None:
|
||||
|
||||
Reference in New Issue
Block a user