From 3a7502617647317c17a0dc94d6ed0584e8305a83 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 28 Feb 2025 19:43:18 +0800 Subject: [PATCH] update wan input params --- diffsynth/pipelines/wan_video.py | 11 ++++++++++- diffsynth/schedulers/flow_match.py | 4 +++- diffsynth/trainers/text_to_image.py | 2 +- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 82ad76c..7a864e3 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -30,6 +30,8 @@ class WanVideoPipeline(BasePipeline): self.dit: WanModel = None self.vae: WanVideoVAE = None self.model_names = ['text_encoder', 'dit', 'vae'] + self.height_division_factor = 16 + self.width_division_factor = 16 def enable_vram_management(self, num_persistent_param_in_dit=None): @@ -202,17 +204,24 @@ class WanVideoPipeline(BasePipeline): num_frames=81, cfg_scale=5.0, num_inference_steps=50, + sigma_shift=5.0, tiled=True, tile_size=(30, 52), tile_stride=(15, 26), progress_bar_cmd=tqdm, progress_bar_st=None, ): + # Parameter check + height, width = self.check_resize_height_width(height, width) + if num_frames % 4 != 1: + num_frames = (num_frames + 2) // 4 * 4 + 1 + print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.") + # Tiler parameters tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} # Scheduler - self.scheduler.set_timesteps(num_inference_steps, denoising_strength) + self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift) # Initialize noise noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device) diff --git a/diffsynth/schedulers/flow_match.py b/diffsynth/schedulers/flow_match.py index aea6757..fde8849 100644 --- a/diffsynth/schedulers/flow_match.py +++ b/diffsynth/schedulers/flow_match.py @@ -15,7 +15,9 @@ class FlowMatchScheduler(): self.set_timesteps(num_inference_steps) - def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False): + def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None): + if shift is not None: + self.shift = shift sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength if self.extra_one_step: self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1] diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py index 00a352e..3177191 100644 --- a/diffsynth/trainers/text_to_image.py +++ b/diffsynth/trainers/text_to_image.py @@ -278,7 +278,7 @@ def launch_training_task(model, args): strategy=args.training_strategy, default_root_dir=args.output_path, accumulate_grad_batches=args.accumulate_grad_batches, - callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], + callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)] ) trainer.fit(model=model, train_dataloaders=train_loader)