From 4e3a184199d8c164e975081757755635d0eb5b80 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 14 Oct 2024 10:00:32 +0800 Subject: [PATCH 1/3] update flux training --- diffsynth/schedulers/flow_match.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/diffsynth/schedulers/flow_match.py b/diffsynth/schedulers/flow_match.py index 6678dc5..ab965ee 100644 --- a/diffsynth/schedulers/flow_match.py +++ b/diffsynth/schedulers/flow_match.py @@ -16,17 +16,13 @@ class FlowMatchScheduler(): sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps) self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas) + self.timesteps = self.sigmas * self.num_train_timesteps if training: - self.timesteps = torch.linspace(1000, 0, num_inference_steps) - - # prepare timestep weights - x = torch.arange(num_inference_steps, dtype=torch.float32) + x = self.timesteps y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2) y_shifted = y - y.min() bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum()) self.linear_timesteps_weights = bsmntw_weighing - else: - self.timesteps = self.sigmas * self.num_train_timesteps def step(self, model_output, timestep, sample, to_final=False): From 58f89ceec974abaa829afec390a3e30e3e0f6fb7 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 14 Oct 2024 17:51:12 +0800 Subject: [PATCH 2/3] update examples --- examples/train/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/train/README.md b/examples/train/README.md index 5c99813..2d9bc94 100644 --- a/examples/train/README.md +++ b/examples/train/README.md @@ -16,7 +16,7 @@ Image Examples of fine-tuned LoRA. The prompt is "一只小狗蹦蹦跳跳,周 ## Install additional packages ``` -pip install peft lightning +pip install peft lightning pandas ``` ## Prepare your dataset From 540c0369885d346abc255c383243c40c54df5fb7 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 14 Oct 2024 18:57:54 +0800 Subject: [PATCH 3/3] add alpha to lora converter --- diffsynth/trainers/text_to_image.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py index 6dc24e6..3177474 100644 --- a/diffsynth/trainers/text_to_image.py +++ b/diffsynth/trainers/text_to_image.py @@ -18,6 +18,7 @@ class LightningModelForT2ILoRA(pl.LightningModule): self.learning_rate = learning_rate self.use_gradient_checkpointing = use_gradient_checkpointing self.state_dict_converter = state_dict_converter + self.lora_alpha = None def load_models(self): @@ -34,6 +35,7 @@ class LightningModelForT2ILoRA(pl.LightningModule): def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian"): # Add LoRA to UNet + self.lora_alpha = lora_alpha if init_lora_weights == "kaiming": init_lora_weights = True @@ -94,7 +96,7 @@ class LightningModelForT2ILoRA(pl.LightningModule): if name in trainable_param_names: lora_state_dict[name] = param if self.state_dict_converter is not None: - lora_state_dict = self.state_dict_converter(lora_state_dict) + lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha) checkpoint.update(lora_state_dict)