mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 23:08:13 +00:00
update doc
This commit is contained in:
@@ -178,15 +178,26 @@ class BasePipeline(torch.nn.Module):
|
||||
def get_vram(self):
|
||||
return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
|
||||
|
||||
def get_module(self, model, name):
|
||||
if "." in name:
|
||||
name, suffix = name[:name.index(".")], name[name.index(".") + 1:]
|
||||
if name.isdigit():
|
||||
return self.get_module(model[int(name)], suffix)
|
||||
else:
|
||||
return self.get_module(getattr(model, name), suffix)
|
||||
else:
|
||||
return getattr(model, name)
|
||||
|
||||
def freeze_except(self, model_names):
|
||||
for name, model in self.named_children():
|
||||
if name in model_names:
|
||||
model.train()
|
||||
model.requires_grad_(True)
|
||||
else:
|
||||
model.eval()
|
||||
model.requires_grad_(False)
|
||||
self.eval()
|
||||
self.requires_grad_(False)
|
||||
for name in model_names:
|
||||
module = self.get_module(self, name)
|
||||
if module is None:
|
||||
print(f"No {name} models in the pipeline. We cannot enable training on the model. If this occurs during the data processing stage, it is normal.")
|
||||
continue
|
||||
module.train()
|
||||
module.requires_grad_(True)
|
||||
|
||||
|
||||
def blend_with_mask(self, base, addition, mask):
|
||||
|
||||
Reference in New Issue
Block a user