RWKV_RESCALE_LAYER 999 for music model

This commit is contained in:
josc146 2023-12-04 17:51:21 +08:00
parent 8f0fc7db56
commit 861e245062

View File

@ -511,6 +511,9 @@ def get_tokenizer(tokenizer_len: int):
def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
if "midi" in model.lower() or "abc" in model.lower():
os.environ["RWKV_RESCALE_LAYER"] = "999"
# dynamic import to make RWKV_CUDA_ON work
if rwkv_beta:
from rwkv_pip.beta.model import (