mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
Merge pull request #378 from modelscope/wan-video-params
update wan input params
This commit is contained in:
@@ -30,6 +30,8 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
self.dit: WanModel = None
|
self.dit: WanModel = None
|
||||||
self.vae: WanVideoVAE = None
|
self.vae: WanVideoVAE = None
|
||||||
self.model_names = ['text_encoder', 'dit', 'vae']
|
self.model_names = ['text_encoder', 'dit', 'vae']
|
||||||
|
self.height_division_factor = 16
|
||||||
|
self.width_division_factor = 16
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||||
@@ -202,17 +204,24 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
num_frames=81,
|
num_frames=81,
|
||||||
cfg_scale=5.0,
|
cfg_scale=5.0,
|
||||||
num_inference_steps=50,
|
num_inference_steps=50,
|
||||||
|
sigma_shift=5.0,
|
||||||
tiled=True,
|
tiled=True,
|
||||||
tile_size=(30, 52),
|
tile_size=(30, 52),
|
||||||
tile_stride=(15, 26),
|
tile_stride=(15, 26),
|
||||||
progress_bar_cmd=tqdm,
|
progress_bar_cmd=tqdm,
|
||||||
progress_bar_st=None,
|
progress_bar_st=None,
|
||||||
):
|
):
|
||||||
|
# Parameter check
|
||||||
|
height, width = self.check_resize_height_width(height, width)
|
||||||
|
if num_frames % 4 != 1:
|
||||||
|
num_frames = (num_frames + 2) // 4 * 4 + 1
|
||||||
|
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
|
||||||
|
|
||||||
# Tiler parameters
|
# Tiler parameters
|
||||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||||
|
|
||||||
# Scheduler
|
# Scheduler
|
||||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
|
||||||
|
|
||||||
# Initialize noise
|
# Initialize noise
|
||||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
|
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
|
||||||
|
|||||||
@@ -15,7 +15,9 @@ class FlowMatchScheduler():
|
|||||||
self.set_timesteps(num_inference_steps)
|
self.set_timesteps(num_inference_steps)
|
||||||
|
|
||||||
|
|
||||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
|
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
|
||||||
|
if shift is not None:
|
||||||
|
self.shift = shift
|
||||||
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
||||||
if self.extra_one_step:
|
if self.extra_one_step:
|
||||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
|||||||
@@ -278,7 +278,7 @@ def launch_training_task(model, args):
|
|||||||
strategy=args.training_strategy,
|
strategy=args.training_strategy,
|
||||||
default_root_dir=args.output_path,
|
default_root_dir=args.output_path,
|
||||||
accumulate_grad_batches=args.accumulate_grad_batches,
|
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||||
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
|
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
|
||||||
)
|
)
|
||||||
trainer.fit(model=model, train_dataloaders=train_loader)
|
trainer.fit(model=model, train_dataloaders=train_loader)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user