test-time-training / ttt-lm-jax

Official JAX implementation of Learning to (Learn at Test Time): RNNs with Expressive Hidden States

GPU OOM when training ttt_mlp #8

Closed shuxiaobo closed 2 months ago

shuxiaobo commented 2 months ago

Hi, I have installed gpu_requirements, but when I run scripts/ttt_mlp/125m.sh on 8 x 80GB A800 GPUs, it raises an OOM error, even though I set

import os

# Set these before JAX initializes its GPU backend (safest: before `import jax`).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".99"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

Can anyone help me solve this?

The logs look like this:


Training ttt_transformer_125m:   0%|          | 1/4800 [00:00<?, ?it/s]
2024-08-15 00:19:42.875004: W external/xla/xla/service/hlo_rematerialization.cc:2202] Can't reduce memory use below 59.51GiB (63899320320 bytes) by rematerialization; only reduced to 107.26GiB (115168630954 bytes), down from 187.49GiB (201315472062 bytes) originally
warning: Linking two modules of different target triples: 'LLVMDialectModule' is 'nvptx64-nvidia-gpulibs' whereas '' is 'nvptx64-nvidia-cuda'

2024-08-15 00:19:52.347184: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to allocate request for 108.03GiB (116000729080B) on device ordinal 5
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    1.23GiB
              constant allocation:    1.00MiB
        maybe_live_out allocation:    1.22GiB
     preallocated temp allocation:  108.03GiB
  preallocated temp fragmentation:       124B (0.00%)
                 total allocation:  109.26GiB
              total fragmentation:    1.03MiB (0.00%)
Peak buffers:
    Buffer 1:
        Size: 7.81GiB
        Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(jit(take_along_axis)))/scatter-add[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1, 2), scatter_dims_to_operand_dims=(0, 1, 2)) indices_are_sorted=False unique_indices=False mode=GatherScatterMode.FILL_OR_DROP]" source_file="/apdcephfs_qy3/share_1502809/shaneshu/TTT/ttt-lm-jax/ttt/infra/jax_utils.py" source_line=262
        XLA Label: fusion
        Shape: f32[32,2048,32001]
        ==========================

    Buffer 2:
        Size: 7.81GiB
        Operator: op_name="pjit(train_step)/jit(main)/jvp(jit(log_softmax))/sub" source_file="/apdcephfs_qy3/share_1502809/shaneshu/TTT/ttt-lm-jax/ttt/infra/jax_utils.py" source_line=262
        XLA Label: fusion
        Shape: f32[32,2048,32001]
        ==========================

    Buffer 3:
        Size: 6.00GiB
        Operator: op_name="pjit(train_step)/jit(main)/jvp(CausalLM)/model/h/11/seq_modeling_block/exp" source_file="/apdcephfs_qy3/share_1502809/shaneshu/TTT/ttt-lm-jax/ttt/models/model.py" source_line=596 deduplicated_name="triton_softmax.1"
        XLA Label: fusion
        Shape: f32[32,12,2048,2048]
        ==========================

    Buffer 4:
        Size: 6.00GiB
        Operator: op_name="pjit(train_step)/jit(main)/jvp(CausalLM)/model/h/10/seq_modeling_block/exp" source_file="/apdcephfs_qy3/share_1502809/shaneshu/TTT/ttt-lm-jax/ttt/models/model.py" source_line=596 deduplicated_name="triton_softmax.1"
        XLA Label: fusion
        Shape: f32[32,12,2048,2048]
        ==========================

    Buffer 5:
        Size: 6.00GiB
        Operator: op_name="pjit(train_step)/jit(main)/jvp(CausalLM)/model/h/9/seq_modeling_block/exp" source_file="/apdcephfs_qy3/share_1502809/shaneshu/TTT/ttt-lm-jax/ttt/models/model.py" source_line=596 deduplicated_name="triton_softmax.1"
        XLA Label: fusion
        Shape: f32[32,12,2048,2048]
        ==========================

    Buffer 6:
        Size: 6.00GiB
        Operator: op_name="pjit(train_step)/jit(main)/jvp(CausalLM)/model/h/8/seq_modeling_block/exp" source_file="/apdcephfs_qy3/share_1502809/shaneshu/TTT/ttt-lm-jax/ttt/models/model.py" source_line=596 deduplicated_name="triton_softmax.1"
        XLA Label: fusion
        Shape: f32[32,12,2048,2048]
        ==========================

    Buffer 7:
        Size: 6.00GiB
        Operator: op_name="pjit(train_step)/jit(main)/jvp(CausalLM)/model/h/7/seq_modeling_block/exp" source_file="/apdcephfs_qy3/share_1502809/shaneshu/TTT/ttt-lm-jax/ttt/models/model.py" source_line=596 deduplicated_name="triton_softmax.1"
        XLA Label: fusion
        Shape: f32[32,12,2048,2048]
        ==========================

    Buffer 8:
        Size: 6.00GiB
        Operator: op_name="pjit(train_step)/jit(main)/jvp(CausalLM)/model/h/6/seq_modeling_block/exp" source_file="/apdcephfs_qy3/share_1502809/shaneshu/TTT/ttt-lm-jax/ttt/models/model.py" source_line=596 deduplicated_name="triton_softmax.1"
        XLA Label: fusion
        Shape: f32[32,12,2048,2048]
        ==========================

    Buffer 9:
        Size: 6.00GiB
        Operator: op_name="pjit(train_step)/jit(main)/jvp(CausalLM)/model/h/5/seq_modeling_block/exp" source_file="/apdcephfs_qy3/share_1502809/shaneshu/TTT/ttt-lm-jax/ttt/models/model.py" source_line=596 deduplicated_name="triton_softmax.1"
        XLA Label: fusion
        Shape: f32[32,12,2048,2048]
        ==========================

    Buffer 10:
        Size: 6.00GiB
        Operator: op_name="pjit(train_step)/jit(main)/jvp(CausalLM)/model/h/4/seq_modeling_block/exp" source_file="/apdcephfs_qy3/share_1502809/shaneshu/TTT/ttt-lm-jax/ttt/models/model.py" source_line=596 deduplicated_name="triton_softmax.1"
        XLA Label: fusion
        Shape: f32[32,12,2048,2048]
        ==========================

    Buffer 11:
        Size: 6.00GiB
        Operator: op_name="pjit(train_step)/jit(main)/jvp(CausalLM)/model/h/3/seq_modeling_block/exp" source_file="/apdcephfs_qy3/share_1502809/shaneshu/TTT/ttt-lm-jax/ttt/models/model.py" source_line=596 deduplicated_name="triton_softmax.1"
        XLA Label: fusion
        Shape: f32[32,12,2048,2048]
        ==========================

    Buffer 12:
        Size: 6.00GiB
        Operator: op_name="pjit(train_step)/jit(main)/jvp(CausalLM)/model/h/2/seq_modeling_block/exp" source_file="/apdcephfs_qy3/share_1502809/shaneshu/TTT/ttt-lm-jax/ttt/models/model.py" source_line=596 deduplicated_name="triton_softmax.1"
        XLA Label: fusion
        Shape: f32[32,12,2048,2048]
        ==========================

    Buffer 13:
        Size: 6.00GiB
        Operator: op_name="pjit(train_step)/jit(main)/jvp(CausalLM)/model/h/1/seq_modeling_block/exp" source_file="/apdcephfs_qy3/share_1502809/shaneshu/TTT/ttt-lm-jax/ttt/models/model.py" source_line=596 deduplicated_name="triton_softmax.1"
        XLA Label: fusion
        Shape: f32[32,12,2048,2048]
        ==========================

    Buffer 14:
        Size: 6.00GiB
        Operator: op_name="pjit(train_step)/jit(main)/jvp(CausalLM)/model/h/0/seq_modeling_block/exp" source_file="/apdcephfs_qy3/share_1502809/shaneshu/TTT/ttt-lm-jax/ttt/models/model.py" source_line=596 deduplicated_name="triton_softmax.1"
        XLA Label: fusion
        Shape: f32[32,12,2048,2048]
        ==========================

    Buffer 15:
        Size: 3.91GiB
        Operator: op_name="pjit(train_step)/jit(main)/convert_element_type[new_dtype=float16 weak_type=False]" source_file="/apdcephfs_qy3/share_1502809/shaneshu/TTT/ttt-lm-jax/ttt/infra/jax_utils.py" source_line=260
        XLA Label: fusion
        Shape: f16[32,2048,32001]
        ==========================

2024-08-15 00:19:52.354780: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to allocate request for 108.03GiB (116000729080B) on device ordinal 4
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    1.23GiB
              constant allocation:    1.00MiB
        maybe_live_out allocation:    1.22GiB
     preallocated temp allocation:  108.03GiB
  preallocated temp fragmentation:       124B (0.00%)
                 total allocation:  109.26GiB
              total fragmentation:    1.03MiB (0.00%)
Peak buffers:
    Buffer 1:
....
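As a quick sanity check on the "Peak buffers" list, the reported sizes follow directly from the shapes and dtypes. The sketch below recomputes them; the helper `buffer_gib` is hypothetical, and reading the leading dimension 32 as the per-device batch (BS=256 split over 8 data-parallel GPUs) is an assumption, not something the log states.

```python
def buffer_gib(shape, bytes_per_elem):
    """Size of a dense buffer in GiB (2**30 bytes), the unit XLA reports."""
    n = 1
    for d in shape:
        n *= d
    return n * bytes_per_elem / 2**30


# Shapes copied from the "Peak buffers" section of the log.
logits_f32 = buffer_gib((32, 2048, 32001), 4)    # Buffers 1-2: log_softmax / scatter-add
logits_f16 = buffer_gib((32, 2048, 32001), 2)    # Buffer 15: convert to float16
block_exp = buffer_gib((32, 12, 2048, 2048), 4)  # Buffers 3-14: one per layer (h/0 .. h/11)

print(f"{logits_f32:.2f} {logits_f16:.2f} {block_exp:.2f}")  # 7.81 3.91 6.00
print(f"12 layers: {12 * block_exp:.2f} GiB")                 # 12 layers: 72.00 GiB
```

The twelve per-layer `exp` buffers alone add up to 72 GiB, already above the 59.51 GiB target that the rematerialization pass reported it could not reach.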
karan-dalal commented 2 months ago

What batch size and context length are you using? You may find it helpful to decrease batch size and use the gradient accumulation flag.
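For readers new to JAX, here is a minimal sketch of what gradient accumulation does, using a toy loss. It is not the repo's implementation of its `accum_steps` flag (this thread does not show that code); `loss_fn` and `accumulated_grads` are hypothetical names. The batch is split into equal micro-batches, `jax.lax.scan` processes them one after another, and the summed gradients are averaged, so peak activation memory roughly tracks the micro-batch rather than the full batch.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Toy mean-squared-error loss standing in for the LM loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def accumulated_grads(params, x, y, accum_steps):
    """Mean loss and gradients over `accum_steps` equal micro-batches.

    lax.scan runs the micro-batches sequentially, so only one micro-batch's
    activations need to be live at a time. Assumes the batch divides evenly.
    """
    micro_x = x.reshape(accum_steps, -1, *x.shape[1:])
    micro_y = y.reshape(accum_steps, -1, *y.shape[1:])
    grad_fn = jax.value_and_grad(loss_fn)

    def step(carry, batch):
        loss_sum, grad_sum = carry
        loss, grads = grad_fn(params, *batch)
        grad_sum = jax.tree_util.tree_map(jnp.add, grad_sum, grads)
        return (loss_sum + loss, grad_sum), None

    init = (jnp.zeros(()), jax.tree_util.tree_map(jnp.zeros_like, params))
    (loss_sum, grad_sum), _ = jax.lax.scan(step, init, (micro_x, micro_y))
    return loss_sum / accum_steps, jax.tree_util.tree_map(
        lambda g: g / accum_steps, grad_sum
    )


# Usage: 2 micro-batches of 4 should match one full batch of 8.
kx, ky = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (8, 3))
y = jax.random.normal(ky, (8,))
params = {"w": jnp.ones(3), "b": jnp.zeros(())}

loss, grads = jax.jit(accumulated_grads, static_argnums=3)(params, x, y, 2)
full_loss, full_grads = jax.value_and_grad(loss_fn)(params, x, y)
print(loss, full_loss)
```

Because the micro-batches are equal in size, the mean of the micro-batch losses equals the full-batch mean, so the accumulated gradients match a single full-batch step up to floating-point error.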

shuxiaobo commented 2 months ago

Hi @karan-dalal, I'm using the repo's default parameters:

DATA_PATH=/data/nlp/pile
DATA_NAME="the_pile" # "books3" 
# NOTE: CACHE_PATH is passed to --data_cache_path below but is not set in this snippet.

# Product should equal 0.5 million
SEQ_LEN=2048
BS=256

# Experiment details
EXP_NAME=ttt_tttmlp_125m
EXP_DIR=ckpt

sudo mkdir -p ${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME};
cd ../..

# export XLA_PYTHON_CLIENT_PREALLOCATE=false
# export XLA_PYTHON_CLIENT_MEM_FRACTION=0.1
# export XLA_PYTHON_CLIENT_ALLOCATOR=platform

python3 -m ttt.train \
        --mesh_dim='!-1,1,1' \
        --dtype='fp32' \
        --total_steps=4800 \
        --save_checkpoint_freq=1000 \
        --save_milestone_freq=2000 \
        --load_model_config='125m-TTT' \
        --update_model_config="dict(seq_modeling_block='ttt_mlp', ttt_base_lr=0.1, ttt_base_lr_init=0.01, ttt_base_lr_warmup=480)" \
        --dataset_path=${DATA_PATH} \
        --dataset_name=${DATA_NAME} \
        --data_cache_path=${CACHE_PATH} \
        --seq_length=${SEQ_LEN} \
        --global_batch_size=${BS} \
        --optimizer.type='adamw' \
        --optimizer.adamw_optimizer.weight_decay=0.1 \
        --optimizer.adamw_optimizer.lr=3e-3 \
        --optimizer.adamw_optimizer.end_lr=1e-5 \
        --optimizer.adamw_optimizer.lr_warmup_steps=480 \
        --optimizer.adamw_optimizer.lr_decay_steps=4800 \
        --exp_dir=${EXP_DIR} \
        --exp_name=${EXP_NAME}

I'm a beginner with JAX, but it seems weird to hit an OOM running a 125M model on an 80GB A800. When I change the mesh to '!-1,1,4' to turn on model parallelism (MP), I get the same preallocation failure; for example, Buffer 14 is still 6.00GiB. I expected it to use about 1/4 of the memory, since model parallelism should reduce the activation memory.
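One reading that reproduces the unchanged 6.00GiB buffer is sketched below. It assumes the EasyLM-style mesh convention of (dp, fsdp, mp) axes with a single -1 absorbing the remaining devices, that the batch axis is sharded over dp x fsdp only, and (hypothetically) that the 12-wide axis of the `exp` buffer is split across mp. The helpers `resolve_mesh` and `per_device_batch` are illustrative, not the repo's API. Under these assumptions, '!-1,1,4' on 8 GPUs gives dp=2, so each device holds 4x more sequences while holding 4x fewer heads, and the buffer size stays the same.

```python
import math


def resolve_mesh(spec: str, n_devices: int) -> tuple[int, ...]:
    """Resolve a mesh string like '!-1,1,4' into concrete axis sizes.

    Assumption: a leading '!' is only a flag (ignored here) and a single -1
    absorbs whatever devices the other axes leave over.
    """
    dims = [int(d) for d in spec.lstrip("!").split(",")]
    known = math.prod(d for d in dims if d != -1)
    if n_devices % known:
        raise ValueError(f"{spec!r} does not divide {n_devices} devices")
    return tuple(n_devices // known if d == -1 else d for d in dims)


def per_device_batch(global_batch: int, dp: int, fsdp: int) -> int:
    # Assumption: the batch axis is sharded over dp x fsdp, not over mp.
    return global_batch // (dp * fsdp)


for spec in ["!-1,1,1", "!-1,1,4"]:
    dp, fsdp, mp = resolve_mesh(spec, n_devices=8)
    b = per_device_batch(256, dp, fsdp)
    heads = 12 // mp  # hypothetical: the 12-wide axis is split across mp
    gib = b * heads * 2048 * 2048 * 4 / 2**30
    print(spec, (dp, fsdp, mp), b, heads, f"{gib:.2f} GiB")  # 6.00 GiB in both cases
```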

shuxiaobo commented 2 months ago
[image]

This is the GPU utilization during training.

karan-dalal commented 2 months ago

You should not need MP for a 125M model. Try using mesh shape -1, 1, 1 and increasing grad accum (accum_steps) until it fits.
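A rough way to guess how many accumulation steps are needed is to scale the reported temp allocation by the micro-batch size. This is a crude sketch: it assumes temp memory scales linearly with the per-device micro-batch (ignoring fixed costs) and that `accum_steps` splits each device's 32 sequences into equal micro-batches; `estimate` is a hypothetical helper. Tokens per optimizer step (the "0.5 million" in the script) do not change.

```python
SEQ_LEN, GLOBAL_BS, N_DEVICES = 2048, 256, 8
REPORTED_TEMP_GIB = 108.03  # "preallocated temp allocation" from the log (no accumulation)


def estimate(accum_steps):
    per_device = GLOBAL_BS // N_DEVICES  # 32, matching the leading dim in the log
    micro = per_device // accum_steps    # assumption: accumulation splits each device's batch
    temp = REPORTED_TEMP_GIB * micro / per_device  # crude linear-scaling assumption
    tokens_per_step = SEQ_LEN * GLOBAL_BS          # unchanged by accumulation
    return micro, temp, tokens_per_step


for k in (1, 2, 4):
    micro, temp, tokens = estimate(k)
    print(f"accum_steps={k}: micro-batch {micro}/device, ~{temp:.1f} GiB temp, {tokens} tokens/step")
```

Under this estimate, accum_steps=2 brings the temp allocation to roughly 54 GiB, below both the 80GB card and the 59.51 GiB rematerialization target in the log.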

shuxiaobo commented 2 months ago

Thanks for the suggestion! I tried accum_steps=2 and it works very well now :)

I also tried MP=2 with mesh_dim=!-1,1,4, but it still OOMs. And with dtype=fp16 the loss blows up T_T

LeoXinhaoLee commented 2 months ago

As an alternative to mesh shape -1,1,1, you can also try mesh shape !1,-1,1 and see whether you can then use accum_steps=1 to speed things up. You can also try dtype=bf16, which makes the activation dtype bf16; bf16 is known to be much more stable than fp16.
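One reason bf16 tends to be more stable than fp16 is dynamic range: both are 16-bit, but bf16 keeps float32's 8-bit exponent and gives up mantissa precision instead. The sketch below, using only `jax.numpy`, shows a moderately large value overflowing to `inf` in fp16 while staying finite in bf16. Whether overflow is the exact cause of the loss blow-up reported above is not established in this thread.

```python
import jax.numpy as jnp

# Largest finite value each 16-bit format can represent.
print(jnp.finfo(jnp.float16).max)   # 65504
print(jnp.finfo(jnp.bfloat16).max)  # about 3.39e38, same order as float32

# A value that overflows fp16 but not bf16.
x = jnp.array(70000.0, dtype=jnp.float32)
print(x.astype(jnp.float16))   # inf
print(x.astype(jnp.bfloat16))  # finite; rounds to 70144 (coarser mantissa)
```

The trade-off is precision: bf16 has only 8 significand bits, so values are rounded more coarsely, which is usually acceptable for activations.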

shuxiaobo commented 2 months ago

Got it, I'll try them and report the results in this issue~