mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
training script
This commit is contained in:
@@ -148,7 +148,10 @@ class BasePipeline(torch.nn.Module):
|
||||
|
||||
def freeze_except(self, model_names):
|
||||
for name, model in self.named_children():
|
||||
if name not in model_names:
|
||||
if name in model_names:
|
||||
model.train()
|
||||
model.requires_grad_(True)
|
||||
else:
|
||||
model.eval()
|
||||
model.requires_grad_(False)
|
||||
|
||||
@@ -214,11 +217,6 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.model_fn = model_fn_wan_video
|
||||
|
||||
|
||||
def train(self):
|
||||
super().train()
|
||||
self.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
|
||||
def training_loss(self, **inputs):
|
||||
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
||||
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
Reference in New Issue
Block a user