RWKV_RESCALE_LAYER 999 for music model
This commit is contained in:
parent
8f0fc7db56
commit
861e245062
@ -511,6 +511,9 @@ def get_tokenizer(tokenizer_len: int):
|
||||
def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
|
||||
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
|
||||
|
||||
if "midi" in model.lower() or "abc" in model.lower():
|
||||
os.environ["RWKV_RESCALE_LAYER"] = "999"
|
||||
|
||||
# dynamic import to make RWKV_CUDA_ON work
|
||||
if rwkv_beta:
|
||||
from rwkv_pip.beta.model import (
|
||||
|
Loading…
Reference in New Issue
Block a user