Closed shuxiaobo closed 2 months ago
What batch size and context length are you using? You may find it helpful to decrease batch size and use the gradient accumulation flag.
Hi @karan-dalal, I use the default params of this repo:
DATA_PATH=/data/nlp/pile
DATA_NAME="the_pile" # "books3"
# Product should equal 0.5 million
SEQ_LEN=2048
BS=256
# Experiment details
EXP_NAME=ttt_tttmlp_125m
EXP_DIR=ckpt
sudo mkdir -p ${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME};
cd ../..
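# Optional XLA GPU memory flags (uncomment if needed): disable upfront preallocation,
# cap the preallocated fraction, or switch to the on-demand 'platform' allocator.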
# export XLA_PYTHON_CLIENT_PREALLOCATE=false
# export XLA_PYTHON_CLIENT_MEM_FRACTION=0.1
# export XLA_PYTHON_CLIENT_ALLOCATOR=platform
python3 -m ttt.train \
--mesh_dim='!-1,1,1' \
--dtype='fp32' \
--total_steps=4800 \
--save_checkpoint_freq=1000 \
--save_milestone_freq=2000 \
--load_model_config='125m-TTT' \
--update_model_config="dict(seq_modeling_block='ttt_mlp', ttt_base_lr=0.1, ttt_base_lr_init=0.01, ttt_base_lr_warmup=480)" \
--dataset_path=${DATA_PATH} \
--dataset_name=${DATA_NAME} \
--data_cache_path=${CACHE_PATH} \
--seq_length=${SEQ_LEN} \
--global_batch_size=${BS} \
--optimizer.type='adamw' \
--optimizer.adamw_optimizer.weight_decay=0.1 \
--optimizer.adamw_optimizer.lr=3e-3 \
--optimizer.adamw_optimizer.end_lr=1e-5 \
--optimizer.adamw_optimizer.lr_warmup_steps=480 \
--optimizer.adamw_optimizer.lr_decay_steps=4800 \
--exp_dir=${EXP_DIR} \
--exp_name=${EXP_NAME}
I am a beginner with JAX, but I think it's weird to hit OOM running a 125M model on an 80G A800. When I change the mesh to '!-1,1,4', which turns on MP, the same preallocated memory shows up, e.g. Buffer 14: Size: 6.00GiB. I would expect roughly 1/4 of the memory to be used, since model parallelism should reduce the activation memory.
This is the GPU utilization during model training:
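For reference, here is a minimal plain-JAX sketch of what the three mesh axes roughly correspond to; the axis names, shapes, and sharding below are illustrative assumptions, not this repo's actual partitioning code. Sharding a parameter along the MP axis splits the weights across devices, but whether activation buffers shrink depends on how the forward computation itself is partitioned.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative only: arrange 8 devices into a (data=2, fsdp=1, mp=4) mesh,
# loosely analogous to mesh_dim='!-1,1,4'.
devices = np.array(jax.devices()).reshape(2, 1, 4)
mesh = Mesh(devices, axis_names=("dp", "fsdp", "mp"))

# Shard a hypothetical weight matrix along its output dim over the 'mp' axis,
# so each device holds a 1/4 slice of this parameter.
w = jnp.zeros((768, 3072))
w = jax.device_put(w, NamedSharding(mesh, P(None, "mp")))
print(w.sharding)  # shows the per-device partitioning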
You should not need MP for a 125M model. Try using mesh shape -1, 1, 1 and increasing grad accum (accum_steps) until it fits.
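In case it is useful, gradient accumulation in a JAX/optax training loop can be emulated with optax.MultiSteps; this is a generic sketch under the assumption of a plain optax optimizer, not this repo's accum_steps implementation.

import jax
import jax.numpy as jnp
import optax

# Wrap the optimizer so updates are applied only every k micro-batches,
# with gradients averaged in between. With k=2 the per-step batch can be
# halved while keeping the same effective global batch size.
base_opt = optax.adamw(learning_rate=3e-3, weight_decay=0.1)
opt = optax.MultiSteps(base_opt, every_k_schedule=2)

params = {"w": jnp.zeros((4, 4))}
opt_state = opt.init(params)

def loss_fn(params, batch):
    return jnp.mean((batch @ params["w"]) ** 2)

@jax.jit
def train_step(params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = opt.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state

batch = jnp.ones((8, 4))
for micro in jnp.split(batch, 2):  # two micro-batches per optimizer step
    params, opt_state = train_step(params, opt_state, micro)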
Thanks for your suggestion. I tried accum_steps=2 and it works very well now :)
I also tried MP=2 with mesh_dim=!-1,1,4, but it still OOMs.
And dtype=fp16 makes the model loss take off T_T
As an alternative to mesh shape -1, 1, 1, you can also try mesh shape !1, -1, 1 and see if you can then use accum_steps=1 to speed up. You can also try dtype=bf16, which will make the activation dtype bf16 and is known to be much more stable than fp16.
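A rough illustration of why bf16 tends to be safer than fp16 for activations (plain JAX, and only a sketch, not this repo's dtype handling): bf16 keeps fp32's exponent range, so large intermediate values that overflow to inf in fp16 stay finite.

import jax.numpy as jnp

x = jnp.asarray(70000.0, dtype=jnp.float32)  # a large activation value
print(x.astype(jnp.float16))   # inf  -- fp16 max is ~65504, so it overflows
print(x.astype(jnp.bfloat16))  # ~7e4 -- bf16 shares fp32's exponent range

# Typical mixed-precision pattern: keep params/optimizer state in fp32 and
# cast only activations/inputs to bf16 for the forward/backward pass.
params_fp32 = {"w": jnp.zeros((8, 8), dtype=jnp.float32)}
batch_bf16 = jnp.ones((2, 8), dtype=jnp.bfloat16)
out = batch_bf16 @ params_fp32["w"].astype(jnp.bfloat16)
print(out.dtype)  # bfloat16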
Got it, I'll try them and report the results in this issue~
Hi, I have installed gpu_requirements, but when I run scripts/ttt_mlp/125m.sh on 8 * 80G A800 it raises OOM, even when I set … Is there anyone who can help me solve this?
The logs look like this: