RWKV_RESCALE_LAYER 999 for music model
This commit is contained in:
parent
8f0fc7db56
commit
861e245062
@ -511,6 +511,9 @@ def get_tokenizer(tokenizer_len: int):
|
|||||||
def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
|
def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
|
||||||
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
|
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
|
||||||
|
|
||||||
|
if "midi" in model.lower() or "abc" in model.lower():
|
||||||
|
os.environ["RWKV_RESCALE_LAYER"] = "999"
|
||||||
|
|
||||||
# dynamic import to make RWKV_CUDA_ON work
|
# dynamic import to make RWKV_CUDA_ON work
|
||||||
if rwkv_beta:
|
if rwkv_beta:
|
||||||
from rwkv_pip.beta.model import (
|
from rwkv_pip.beta.model import (
|
||||||
|
Loading…
Reference in New Issue
Block a user