support wan tensor parallel (preview)

This commit is contained in:
Artiprocher
2025-03-17 19:39:45 +08:00
parent 39890f023f
commit 04d03500ff
5 changed files with 147 additions and 9 deletions

View File

@@ -225,7 +225,7 @@ class WanVideoPipeline(BasePipeline):
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
# Initialize noise
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32)