mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:39:43 +00:00
65 lines
2.1 KiB
Python
65 lines
2.1 KiB
Python
import torch, os, argparse
|
|
from safetensors.torch import save_file
|
|
|
|
|
|
def load_pl_state_dict(file_path):
    """Load a saved checkpoint and return only its trainable parameters.

    The file is loaded onto the CPU. The set of parameter names to keep is
    read from the top-level ``"trainable_param_names"`` entry. The tensors
    themselves are taken from the checkpoint after unwrapping a ``"module"``
    entry and then a ``"state_dict"`` entry, each only if present.

    Parameter names are normalised by stripping a leading
    ``"_forward_module."`` and then a leading ``"unet."`` before being
    compared against the trainable set.

    Args:
        file_path: Path of the checkpoint file to load.

    Returns:
        dict mapping normalised parameter name -> tensor, restricted to the
        names listed in ``"trainable_param_names"``.

    Raises:
        KeyError: if the checkpoint has no ``"trainable_param_names"`` entry.
    """
    print(f"loading {file_path}")
    # NOTE(review): torch.load unpickles the file; only use on trusted checkpoints.
    checkpoint = torch.load(file_path, map_location="cpu")
    # Read from the top level, before any unwrapping below.
    keep = set(checkpoint["trainable_param_names"])

    weights = checkpoint
    for wrapper_key in ("module", "state_dict"):
        if wrapper_key in weights:
            weights = weights[wrapper_key]

    filtered = {}
    for raw_name, tensor in weights.items():
        clean_name = raw_name
        # Order matters: the wrapper prefix comes off first, then the sub-model prefix.
        for prefix in ("_forward_module.", "unet."):
            if clean_name.startswith(prefix):
                clean_name = clean_name[len(prefix):]
        if clean_name in keep:
            filtered[clean_name] = tensor
    return filtered
|
|
|
|
|
|
def ckpt_to_epochs(ckpt_name):
    """Extract the epoch number from a name such as ``epoch=3-step=1000.ckpt``.

    The value is the text between the first and second ``=``, truncated at
    its first ``-``, converted with ``int``.

    Raises:
        IndexError: if ``ckpt_name`` contains no ``=``.
        ValueError: if the extracted text is not an integer.
    """
    fields = ckpt_name.split("=")
    epoch_field = fields[1]
    epoch_text, _, _ = epoch_field.partition("-")
    return int(epoch_text)
|
|
|
|
|
|
def parse_args():
    """Parse the command-line options of the EMA merging script.

    Options:
        --output_path: folder that holds the checkpoint folders to merge; the
            resulting ``*-ema.safetensors`` files are written there as well.
            Defaults to ``"./"``.
        --gamma: EMA decay factor, must lie in [0, 1]. Each new checkpoint is
            blended as ``ema = gamma * ema + (1 - gamma) * current``.
            Defaults to 0.9.

    Returns:
        argparse.Namespace with ``output_path`` (str) and ``gamma`` (float).

    Exits (via ``parser.error``, status 2) if ``--gamma`` is outside [0, 1]
    or NaN.
    """
    parser = argparse.ArgumentParser(
        description="Merge the trainable parameters of saved checkpoints with an "
                    "exponential moving average (EMA).",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./",
        help="Folder containing the checkpoint folders; the EMA safetensors files are saved here too.",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.9,
        help="Decay factor of the EMA, in [0, 1]: ema = gamma * ema + (1 - gamma) * current.",
    )
    args = parser.parse_args()
    # A gamma outside [0, 1] (or NaN, which fails both comparisons) would turn
    # the "average" into an extrapolation, so reject it up front.
    if not 0.0 <= args.gamma <= 1.0:
        parser.error(f"--gamma must be in [0, 1], got {args.gamma}")
    return args
|
|
|
|
|
|
if __name__ == '__main__':
    import re  # only the script entry point needs it

    # args
    args = parse_args()
    folder = args.output_path
    gamma = args.gamma

    # Collect the checkpoint folders and order them chronologically by
    # (epoch, step). The previous tie-break on the raw name was lexicographic,
    # so within one epoch "step=1000" sorted before "step=500" and the EMA was
    # accumulated out of order.
    ckpt_list = []
    for ckpt_name in os.listdir(folder):
        if not os.path.isdir(os.path.join(folder, ckpt_name)):
            continue
        try:
            epochs = ckpt_to_epochs(ckpt_name)
        except (IndexError, ValueError):
            # Folders not following the "epoch=N-..." naming are not
            # checkpoints; skip them instead of aborting the whole run.
            print(f"skip {ckpt_name}")
            continue
        step_match = re.search(r"step=(\d+)", ckpt_name)
        steps = int(step_match.group(1)) if step_match else -1
        ckpt_list.append((epochs, steps, ckpt_name))
    ckpt_list.sort()
    if not ckpt_list:
        print(f"no checkpoint folders found in {folder}")

    # EMA: seeded with the first checkpoint, then
    # ema = gamma * ema + (1 - gamma) * current for every later one.
    state_dict_ema = None
    for epochs, steps, ckpt_name in ckpt_list:
        state_dict = load_pl_state_dict(
            os.path.join(folder, ckpt_name, "checkpoint", "mp_rank_00_model_states.pt")
        )
        if state_dict_ema is None:
            # .float() makes the running average float32, presumably so that
            # repeated blending of lower-precision weights does not lose precision.
            state_dict_ema = {name: param.float() for name, param in state_dict.items()}
        else:
            for name, param in state_dict.items():
                state_dict_ema[name] = state_dict_ema[name] * gamma + param.float() * (1 - gamma)
        # Save the running average after every checkpoint. Strip only a
        # trailing ".ckpt" and always append the suffix, so the output never
        # collides with the checkpoint folder itself.
        base_name = ckpt_name[:-len(".ckpt")] if ckpt_name.endswith(".ckpt") else ckpt_name
        save_path = os.path.join(folder, f"{base_name}-ema.safetensors")
        print(f"save to {save_path}")
        save_file(state_dict_ema, save_path)
|