########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################


def find_pile_checkpoints(filenames):
    """Return the sorted epoch ids of the ``rwkv*.pth`` checkpoints in *filenames*.

    ``rwkv-<N>.pth`` yields ``N`` and ``rwkv-init.pth`` yields ``-1``, so the
    initial weights sort before every trained epoch. Names that do not start
    with ``rwkv`` and end with ``.pth`` are ignored.

    Raises:
        ValueError: a matching name whose tag is neither ``init`` nor an integer.
        IndexError: a matching name without a ``-`` separator (e.g. ``rwkv.pth``).
    """
    epoch_ids = []
    for fname in filenames:
        if fname.startswith("rwkv") and fname.endswith(".pth"):
            tag = fname.split("-")[1].split(".")[0]
            epoch_ids.append(-1 if tag == "init" else int(tag))
    return sorted(epoch_ids)


def pile_checkpoint_path(proj_dir, epoch_id):
    """Return the checkpoint path for *epoch_id* in *proj_dir* (``-1`` means the initial weights)."""
    if epoch_id == -1:
        return f"{proj_dir}/rwkv-init.pth"
    return f"{proj_dir}/rwkv-{epoch_id}.pth"


if __name__ == "__main__":
    from argparse import ArgumentParser

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

    rank_zero_info("########## work in progress ##########")

    ########################################################################################################
    #
    # example: train a simple L12-D768 RWKV on dummy data
    #
    # python train.py --load_model "" --wandb "" --proj_dir "out" \
    # --data_file "" --data_type "dummy" --vocab_size 0 \
    # --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
    # --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
    # --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
    #
    # example: train a simple L6-D512 RWKV from scratch on enwik8
    #
    # python train.py --load_model "" --wandb "" --proj_dir "out" \
    # --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
    # --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
    # --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
    # --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
    #
    # example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
    #
    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
    # --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
    # --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
    # --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0
    #
    # example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
    #
    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
    # --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
    # --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1

    parser = ArgumentParser()

    parser.add_argument("--load_model", default="", type=str)  # full path, with .pth
    # wandb project name. if "" then don't use wandb
    parser.add_argument("--wandb", default="", type=str)
    parser.add_argument("--proj_dir", default="out", type=str)
    parser.add_argument("--random_seed", default=-1, type=int)  # < 0: don't set a global seed

    parser.add_argument("--data_file", default="", type=str)
    parser.add_argument("--data_type", default="utf-8", type=str)
    # vocab_size = 0 means auto (for char-level LM and .txt data)
    parser.add_argument("--vocab_size", default=0, type=int)

    parser.add_argument("--ctx_len", default=1024, type=int)
    # a mini "epoch" has [epoch_steps] steps
    parser.add_argument("--epoch_steps", default=1000, type=int)
    # train for this many "epochs". will continue afterwards with lr = lr_final
    parser.add_argument("--epoch_count", default=500, type=int)
    # if you load a model trained for x "epochs", set epoch_begin = x
    parser.add_argument("--epoch_begin", default=0, type=int)
    # save the model every [epoch_save] "epochs"
    parser.add_argument("--epoch_save", default=5, type=int)

    # micro batch size (batch size per GPU)
    parser.add_argument("--micro_bsz", default=12, type=int)
    parser.add_argument("--n_layer", default=6, type=int)
    parser.add_argument("--n_embd", default=512, type=int)
    parser.add_argument("--dim_att", default=0, type=int)  # <= 0 means n_embd (set below)
    parser.add_argument("--dim_ffn", default=0, type=int)  # <= 0 means 4 * n_embd (set below)
    # replace first att layer by ffn (sometimes better)
    parser.add_argument("--pre_ffn", default=0, type=int)
    parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
    parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
    parser.add_argument("--tiny_att_layer", default=-999, type=int)  # tiny attention @ which layer

    # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
    parser.add_argument("--lr_init", default=6e-4, type=float)
    parser.add_argument("--lr_final", default=1e-5, type=float)
    parser.add_argument("--warmup_steps", default=0, type=int)  # try 50 if you load a model
    parser.add_argument("--beta1", default=0.9, type=float)
    # use 0.999 when your model is close to convergence
    parser.add_argument("--beta2", default=0.99, type=float)
    parser.add_argument("--adam_eps", default=1e-8, type=float)

    # gradient checkpt: saves VRAM, but slower
    parser.add_argument("--grad_cp", default=0, type=int)
    parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
    parser.add_argument("--my_pile_shift", default=-1, type=int)  # my special pile mode - text shift
    parser.add_argument("--my_pile_edecay", default=0, type=int)
    # layerwise lr for faster convergence (but slower it/s)
    parser.add_argument("--layerwise_lr", default=1, type=int)
    # deepspeed bucket size in MB. 200 seems enough
    parser.add_argument("--ds_bucket_mb", default=200, type=int)
    # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)

    parser.add_argument("--my_img_version", default=0, type=str)
    parser.add_argument("--my_img_size", default=0, type=int)
    parser.add_argument("--my_img_bit", default=0, type=int)
    parser.add_argument("--my_img_clip", default="x", type=str)
    parser.add_argument("--my_img_clip_scale", default=1, type=float)
    parser.add_argument("--my_img_l1_scale", default=0, type=float)
    parser.add_argument("--my_img_encoder", default="x", type=str)
    # parser.add_argument("--my_img_noise_scale", default=0, type=float)
    parser.add_argument("--my_sample_len", default=0, type=int)
    parser.add_argument("--my_ffn_shift", default=1, type=int)
    parser.add_argument("--my_att_shift", default=1, type=int)
    parser.add_argument("--my_pos_emb", default=0, type=int)
    parser.add_argument("--load_partial", default=0, type=int)
    parser.add_argument("--magic_prime", default=0, type=int)
    parser.add_argument("--my_qa_mask", default=0, type=int)
    parser.add_argument("--my_testing", default="", type=str)

    parser.add_argument("--lora", action="store_true")
    parser.add_argument("--lora_load", default="", type=str)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument("--lora_alpha", default=32, type=float)
    parser.add_argument("--lora_dropout", default=0.01, type=float)
    parser.add_argument("--lora_parts", default="att,ln,time", type=str)

    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    ########################################################################################################

    import datetime
    import importlib
    import importlib.util  # `import importlib` alone does not guarantee `importlib.util`
    import math
    import os
    import sys
    import time
    import warnings

    import numpy as np
    import torch
    from torch.utils.data import DataLoader

    # `--strategy` defaults to None when omitted; normalise it so the substring
    # checks below do not raise TypeError ("argument of type 'NoneType' is not iterable").
    strategy_name = args.strategy or ""

    if "deepspeed" in strategy_name:
        import deepspeed
    import pytorch_lightning as pl
    from pytorch_lightning import seed_everything

    if args.random_seed >= 0:
        print(
            f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n"
            * 3
        )
        seed_everything(args.random_seed)

    np.set_printoptions(precision=4, suppress=True, linewidth=200)
    warnings.filterwarnings(
        "ignore", ".*Consider increasing the value of the `num_workers` argument*"
    )
    warnings.filterwarnings(
        "ignore", ".*The progress bar already tracks a metric with the*"
    )
    # os.environ["WDS_SHOW_SEED"] = "1"

    args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
    args.enable_checkpointing = False
    args.replace_sampler_ddp = False
    args.logger = False
    args.gradient_clip_val = 1.0
    args.num_sanity_val_steps = 0
    args.check_val_every_n_epoch = int(1e20)
    args.log_every_n_steps = int(1e20)
    # -1 continue forever.
    # NOTE(review): this is captured before the my_pile_stage block below overrides
    # epoch_count; confirm that is intended.
    args.max_epochs = args.epoch_count
    args.betas = (args.beta1, args.beta2)
    args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
    os.environ["RWKV_T_MAX"] = str(args.ctx_len)
    os.environ["RWKV_MY_TESTING"] = args.my_testing
    if args.dim_att <= 0:
        args.dim_att = args.n_embd
    if args.dim_ffn <= 0:
        args.dim_ffn = args.n_embd * 4

    if args.data_type == "wds_img":
        args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
        args.proj_dir = f"{args.proj_dir}-{args.run_name}"
    else:
        args.run_name = (
            f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
        )
    # exist_ok avoids a check-then-create race when several ranks start at once
    os.makedirs(args.proj_dir, exist_ok=True)

    if args.my_pile_stage > 0:
        magic_prime_bak = args.magic_prime
        if args.ctx_len == 1024:
            args.magic_prime = 324331313
            args.epoch_count = 8043
        elif args.ctx_len == 2048:
            args.magic_prime = 162165671
            args.epoch_count = 4021
        elif args.ctx_len == 4096:
            args.magic_prime = 81082817
            args.epoch_count = 2010
        if args.my_pile_shift < 0:
            if args.ctx_len == 1024:
                args.my_pile_shift = 0
            elif args.ctx_len == 2048:
                args.my_pile_shift = 512
            elif args.ctx_len == 4096:
                args.my_pile_shift = 768
        if magic_prime_bak > 0:
            args.magic_prime = magic_prime_bak  # an explicit --magic_prime wins
        args.epoch_steps = 40320 // args.real_bsz
        assert args.epoch_steps * args.real_bsz == 40320
        if args.my_pile_stage == 2:
            assert args.lr_final == args.lr_init
        if args.my_pile_stage >= 2:  # find latest saved model
            ckpt_ids = find_pile_checkpoints(os.listdir(args.proj_dir))
            if not ckpt_ids:
                raise FileNotFoundError(
                    f"my_pile_stage >= 2 needs an rwkv-*.pth checkpoint in {args.proj_dir}"
                )
            max_p = ckpt_ids[-1]
            if len(ckpt_ids) > 1:
                args.my_pile_prev_p = ckpt_ids[-2]  # in case max_p is corrupted
            args.load_model = pile_checkpoint_path(args.proj_dir, max_p)
            if max_p != -1:
                # resuming a trained checkpoint: use a short warmup
                if args.my_pile_stage == 2:
                    args.warmup_steps = 10
                else:
                    args.warmup_steps = 30
            args.epoch_begin = max_p + 1

    samples_per_epoch = args.epoch_steps * args.real_bsz
    tokens_per_epoch = samples_per_epoch * args.ctx_len

    # Report deepspeed only via a module object that is actually loaded: the bare
    # name `deepspeed` exists only when a deepspeed strategy imported it above.
    _ds_module = sys.modules.get("deepspeed")
    if _ds_module is not None:
        deepspeed_version = getattr(_ds_module, "__version__", "unknown")
    elif importlib.util.find_spec("deepspeed") is not None:
        deepspeed_version = "installed (not imported)"
    else:
        deepspeed_version = "None"

    rank_zero_info(
        f"""
############################################################################
#
# RWKV-4 {str(args.precision).upper()} on {args.num_nodes}x{args.devices} {str(args.accelerator).upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1}, save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
# LoRA = {f'enabled, {args.lora_r} r, {args.lora_alpha} alpha, {args.lora_dropout} dropout, on {args.lora_parts}' if args.lora else 'disabled'}
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
# Found deepspeed {deepspeed_version}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer
#
############################################################################
"""
    )
    rank_zero_info(str(vars(args)) + "\n")

    assert args.data_type in [
        "utf-8",
        "utf-16le",
        "numpy",
        "binidx",
        "dummy",
        "wds_img",
        "uint16",
    ]

    if args.lr_final == 0 or args.lr_init == 0:
        rank_zero_info(
            "\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n"
        )

    assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
    os.environ["RWKV_FLOAT_MODE"] = args.precision
    if args.precision == "fp32":
        for _ in range(10):
            rank_zero_info(
                "\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n"
            )
    if args.precision == "fp16":
        rank_zero_info(
            "\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n"
        )

    os.environ["RWKV_JIT_ON"] = "1"
    if "deepspeed_stage_3" in strategy_name:
        os.environ["RWKV_JIT_ON"] = "0"
    if args.lora and args.grad_cp == 1:
        print(
            "!!!!! LoRA Warning: Gradient Checkpointing requires JIT off, disabling it"
        )
        os.environ["RWKV_JIT_ON"] = "0"

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    tf32_allowed = args.precision != "fp32"
    torch.backends.cudnn.allow_tf32 = tf32_allowed
    torch.backends.cuda.matmul.allow_tf32 = tf32_allowed

    # Translate to the precision values Trainer expects: fp32/tf32 -> 32, fp16 -> 16.
    if "32" in args.precision:
        args.precision = 32
    elif args.precision == "fp16":
        args.precision = 16
    else:
        args.precision = "bf16"

    ########################################################################################################

    from src.trainer import train_callback, generate_init_weight
    from src.dataset import MyDataset

    train_data = MyDataset(args)
    args.vocab_size = train_data.vocab_size

    if args.data_type == "wds_img":
        from src.model_img import RWKV_IMG

        # LoRA is unsupported here, so refuse LoRA runs (the check was previously inverted).
        assert not args.lora, "LoRA not yet supported for RWKV_IMG"
        model = RWKV_IMG(args)
    else:
        from src.model import RWKV, LORA_CONFIG, LoraLinear

        if args.lora:
            assert args.lora_r > 0, "LoRA should have its `r` > 0"
            LORA_CONFIG["r"] = args.lora_r
            LORA_CONFIG["alpha"] = args.lora_alpha
            LORA_CONFIG["dropout"] = args.lora_dropout
            LORA_CONFIG["parts"] = set(str(args.lora_parts).split(","))
            enable_time_finetune = "time" in LORA_CONFIG["parts"]
            enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
        model = RWKV(args)
        # only train lora parameters
        if args.lora:
            model.requires_grad_(False)
            for name, module in model.named_modules():
                # have to check param name since it may have been wrapped by torchscript
                if any(n.startswith("lora_") for n, _ in module.named_parameters()):
                    print(f"LoRA training module {name}")
                    for pname, param in module.named_parameters():
                        param.requires_grad = "lora_" in pname
                elif enable_ln_finetune and ".ln" in name:
                    print(f"LoRA additionally training module {name}")
                    for param in module.parameters():
                        param.requires_grad = True
                elif enable_time_finetune and any(
                    n.startswith("time") for n, _ in module.named_parameters()
                ):
                    for pname, param in module.named_parameters():
                        if pname.startswith("time"):
                            print(f"LoRA additionally training parameter {pname}")
                            param.requires_grad = True

    if (
        len(args.load_model) == 0 or args.my_pile_stage == 1
    ):  # shall we build the initial weights?
        init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
        generate_init_weight(model, init_weight_name)  # save initial weights
        args.load_model = init_weight_name

    rank_zero_info(f"########## Loading {args.load_model}... ##########")
    load_dict = None
    try:
        load_dict = torch.load(args.load_model, map_location="cpu")
        model.load_state_dict(load_dict, strict=(not args.lora))
    except Exception:
        rank_zero_info(f"Bad checkpoint {args.load_model}")
        prev_p = getattr(args, "my_pile_prev_p", None)
        if args.my_pile_stage >= 2 and prev_p is not None:
            # try again using the previous checkpoint
            args.load_model = pile_checkpoint_path(args.proj_dir, prev_p)
            args.epoch_begin = prev_p + 1
            rank_zero_info(f"Trying {args.load_model}")
            load_dict = torch.load(args.load_model, map_location="cpu")
            model.load_state_dict(load_dict, strict=(not args.lora))
        elif args.load_partial != 1 or load_dict is None:
            # Nothing below can recover: don't silently train from un-loaded weights.
            raise

    if args.load_partial == 1:
        # Fill keys missing from the checkpoint with the model's current values, then load.
        for key, value in model.state_dict().items():
            load_dict.setdefault(key, value)
        model.load_state_dict(load_dict, strict=(not args.lora))

    # If using LoRA, the LoRA keys might be missing in the original model
    # model.load_state_dict(load_dict, strict=(not args.lora))
    if os.path.isfile(args.lora_load):
        model.load_state_dict(
            torch.load(args.lora_load, map_location="cpu"), strict=False
        )

    trainer: Trainer = Trainer.from_argparse_args(
        args,
        callbacks=[train_callback(args)],
    )

    if (
        args.lr_init > 1e-4
        or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8
    ):
        if "I_KNOW_WHAT_IM_DOING" in os.environ:
            if trainer.global_rank == 0:
                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                print(
                    f"WARNING: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)"
                )
                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        else:
            if trainer.global_rank == 0:
                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                print(
                    f"ERROR: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)"
                )
                print(
                    "Unless you are sure this is what you want, adjust them accordingly"
                )
                print(
                    '(to suppress this, set environment variable "I_KNOW_WHAT_IM_DOING")'
                )
                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
            # every rank stops; a non-zero status so launchers/CI see the failure
            sys.exit(1)

    if trainer.global_rank == 0:
        for n, tensor in model.state_dict().items():
            shape = [i for i in tensor.shape if i != 1] or [1]
            if len(shape) > 1:
                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
            else:
                print(f"{str(shape[0]).ljust(5)} {n}")

    if "deepspeed" in strategy_name:
        bucket_bytes = args.ds_bucket_mb * 1000 * 1000
        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = bucket_bytes
        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = bucket_bytes

    # must set shuffle=False, persistent_workers=False (because worker is in another thread)
    data_loader = DataLoader(
        train_data,
        shuffle=False,
        pin_memory=True,
        batch_size=args.micro_bsz,
        num_workers=1,
        persistent_workers=False,
        drop_last=True,
    )

    trainer.fit(model, data_loader)