mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
update wan input params
This commit is contained in:
@@ -30,6 +30,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.dit: WanModel = None
|
||||
self.vae: WanVideoVAE = None
|
||||
self.model_names = ['text_encoder', 'dit', 'vae']
|
||||
self.height_division_factor = 16
|
||||
self.width_division_factor = 16
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||
@@ -202,17 +204,24 @@ class WanVideoPipeline(BasePipeline):
|
||||
num_frames=81,
|
||||
cfg_scale=5.0,
|
||||
num_inference_steps=50,
|
||||
sigma_shift=5.0,
|
||||
tiled=True,
|
||||
tile_size=(30, 52),
|
||||
tile_stride=(15, 26),
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Parameter check
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
if num_frames % 4 != 1:
|
||||
num_frames = (num_frames + 2) // 4 * 4 + 1
|
||||
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
|
||||
|
||||
# Initialize noise
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
|
||||
|
||||
Reference in New Issue
Block a user