training script

This commit is contained in:
Artiprocher
2025-05-19 19:02:52 +08:00
parent 675eefa07e
commit 8f10a9c353
5 changed files with 165 additions and 94 deletions

View File

@@ -148,7 +148,10 @@ class BasePipeline(torch.nn.Module):
def freeze_except(self, model_names):
for name, model in self.named_children():
if name not in model_names:
if name in model_names:
model.train()
model.requires_grad_(True)
else:
model.eval()
model.requires_grad_(False)
@@ -214,11 +217,6 @@ class WanVideoPipeline(BasePipeline):
self.model_fn = model_fn_wan_video
def train(self):
super().train()
self.scheduler.set_timesteps(1000, training=True)
def training_loss(self, **inputs):
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)