Merge pull request #378 from modelscope/wan-video-params

update wan input params
This commit is contained in:
Zhongjie Duan
2025-02-28 19:52:20 +08:00
committed by GitHub
3 changed files with 14 additions and 3 deletions

View File

@@ -30,6 +30,8 @@ class WanVideoPipeline(BasePipeline):
self.dit: WanModel = None self.dit: WanModel = None
self.vae: WanVideoVAE = None self.vae: WanVideoVAE = None
self.model_names = ['text_encoder', 'dit', 'vae'] self.model_names = ['text_encoder', 'dit', 'vae']
self.height_division_factor = 16
self.width_division_factor = 16
def enable_vram_management(self, num_persistent_param_in_dit=None): def enable_vram_management(self, num_persistent_param_in_dit=None):
@@ -202,17 +204,24 @@ class WanVideoPipeline(BasePipeline):
num_frames=81, num_frames=81,
cfg_scale=5.0, cfg_scale=5.0,
num_inference_steps=50, num_inference_steps=50,
sigma_shift=5.0,
tiled=True, tiled=True,
tile_size=(30, 52), tile_size=(30, 52),
tile_stride=(15, 26), tile_stride=(15, 26),
progress_bar_cmd=tqdm, progress_bar_cmd=tqdm,
progress_bar_st=None, progress_bar_st=None,
): ):
# Parameter check
height, width = self.check_resize_height_width(height, width)
if num_frames % 4 != 1:
num_frames = (num_frames + 2) // 4 * 4 + 1
print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")
# Tiler parameters # Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Scheduler # Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength) self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
# Initialize noise # Initialize noise
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device) noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)

View File

@@ -15,7 +15,9 @@ class FlowMatchScheduler():
self.set_timesteps(num_inference_steps) self.set_timesteps(num_inference_steps)
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False): def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
if shift is not None:
self.shift = shift
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
if self.extra_one_step: if self.extra_one_step:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1] self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]

View File

@@ -278,7 +278,7 @@ def launch_training_task(model, args):
strategy=args.training_strategy, strategy=args.training_strategy,
default_root_dir=args.output_path, default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches, accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
) )
trainer.fit(model=model, train_dataloaders=train_loader) trainer.fit(model=model, train_dataloaders=train_loader)