diff --git a/diffsynth/schedulers/flow_match.py b/diffsynth/schedulers/flow_match.py index 6678dc5..ab965ee 100644 --- a/diffsynth/schedulers/flow_match.py +++ b/diffsynth/schedulers/flow_match.py @@ -16,17 +16,13 @@ class FlowMatchScheduler(): sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps) self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas) + self.timesteps = self.sigmas * self.num_train_timesteps if training: - self.timesteps = torch.linspace(1000, 0, num_inference_steps) - - # prepare timestep weights - x = torch.arange(num_inference_steps, dtype=torch.float32) + x = self.timesteps y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2) y_shifted = y - y.min() bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum()) self.linear_timesteps_weights = bsmntw_weighing - else: - self.timesteps = self.sigmas * self.num_train_timesteps def step(self, model_output, timestep, sample, to_final=False): diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py index 6dc24e6..3177474 100644 --- a/diffsynth/trainers/text_to_image.py +++ b/diffsynth/trainers/text_to_image.py @@ -18,6 +18,7 @@ class LightningModelForT2ILoRA(pl.LightningModule): self.learning_rate = learning_rate self.use_gradient_checkpointing = use_gradient_checkpointing self.state_dict_converter = state_dict_converter + self.lora_alpha = None def load_models(self): @@ -34,6 +35,7 @@ class LightningModelForT2ILoRA(pl.LightningModule): def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian"): # Add LoRA to UNet + self.lora_alpha = lora_alpha if init_lora_weights == "kaiming": init_lora_weights = True @@ -94,7 +96,7 @@ class LightningModelForT2ILoRA(pl.LightningModule): if name in trainable_param_names: lora_state_dict[name] = param if self.state_dict_converter is not None: - lora_state_dict = self.state_dict_converter(lora_state_dict) + lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha) checkpoint.update(lora_state_dict) diff --git a/examples/train/README.md b/examples/train/README.md index 5c99813..2d9bc94 100644 --- a/examples/train/README.md +++ b/examples/train/README.md @@ -16,7 +16,7 @@ Image Examples of fine-tuned LoRA. The prompt is "一只小狗蹦蹦跳跳,周 ## Install additional packages ``` -pip install peft lightning +pip install peft lightning pandas ``` ## Prepare your dataset