support wan tensor parallel (preview)

This commit is contained in:
Artiprocher
2025-03-17 19:39:45 +08:00
parent 39890f023f
commit 04d03500ff
5 changed files with 147 additions and 9 deletions

View File

@@ -37,7 +37,7 @@ class FlowMatchScheduler():
self.linear_timesteps_weights = bsmntw_weighing
def step(self, model_output, timestep, sample, to_final=False):
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())