bminixhofer / zett

Code for Zero-Shot Tokenizer Transfer
https://arxiv.org/abs/2405.07883

OOM on training Mistral hypernet #8

Open kdcyberdude opened 4 weeks ago

kdcyberdude commented 4 weeks ago

Hi @bminixhofer, I am getting an OOM with the following logs when training a Mistral multilingual hypernet. I have tried this on two A100 (80GB) GPUs as well. Not sure what is wrong!

I have created a branch containing a script to reproduce this. You can run the ./install script on any vast.ai instance: https://github.com/bminixhofer/zett/compare/main...kdcyberdude:zett:main

2024-06-10 00:21:14.077507: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.00GiB (rounded to 1073741824)requested by op                                                                         
2024-06-10 00:21:14.078739: W external/tsl/tsl/framework/bfc_allocator.cc:497] ****************************************************************************************************                                                                              
2024-06-10 00:21:14.078795: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1073741824 bytes.                                                             
BufferAssignment OOM Debugging.                                                                                                                                                                                                                                  
BufferAssignment stats:                                                                                                                                                                                                                                          
             parameter allocation:    1.00GiB                                                                                                                                                                                                                    
              constant allocation:         0B                                                                                                                                                                                                                    
        maybe_live_out allocation:    1.00GiB                                                                                                                                                                                                                    
     preallocated temp allocation:         0B                                                                                                                                                                                                                    
                 total allocation:    2.00GiB                                                                                                                                                                                                                    
              total fragmentation:         0B (0.00%)                                                                                                                                                                                                            
Peak buffers:                                                                                                                                                                                                                                                    
        Buffer 1:                                                                                                                                                                                                                                                
                Size: 1.00GiB                                                                                                                                                                                                                                    
                Entry Parameter Subshape: pred[1,32768,32768]                                                                                                                                                                                                    
                ==========================                                                                                                                                                                                                                       

        Buffer 2:                                                                                                                                                                                                                                                
                Size: 1.00GiB                                                                                                                                                                                                                                    
                XLA Label: fusion                                                                                                                                                                                                                                
                Shape: pred[1,1,32768,32768]                                                                                                                                                                                                                     
                ==========================                                                                                                                                                                                                                       

  0%|                                                                                                                                                                                                                                 | 0/100000 [00:37<?, ?it/s]
Traceback (most recent call last):                                                                                                                                                                                                                               
  File "/workspace/zett/train.py", line 1625, in <module>                                                                                                                                                                                                        
    main()                                                                                                                                                                                                                                                       
  File "/workspace/zett/train.py", line 1526, in main                                                                                                                                                                                                            
    state, train_metric = current_step_fn(state, batch)                                                                                                                                                                                                          
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                          
  File "/workspace/zett/train.py", line 1203, in train_step                                                                                                                                                                                                      
    (loss, (lexical_loss, mean_lexical_overlap)), grad = grad_fn(state.params)                                                                                                                                                                                   
                                                         ^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                   
  File "/workspace/zett/train.py", line 1116, in compute_loss                                                                                                                                                                                                    
    ) = compute_embeddings_and_logits(                                                                                                                                                                                                                           
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                                           
  File "/workspace/zett/train.py", line 1092, in compute_embeddings_and_logits                                                                                                                                                                                   
    logits = model_fn(                                                                                                                                                                                                                                           
             ^^^^^^^^^                                                                                                                                                                                                                                           
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 502, in __call__                                                                                                                           
    outputs = self.module.apply(                                                                                                                                                                                                                                 
              ^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                                                 
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 677, in __call__                                                                                                                           
    outputs = self.model(                                                                                                                                                                                                                                        
              ^^^^^^^^^^^                                                                                                                                                                                                                                        
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 605, in __call__                                                                                                                           
    outputs = self.layers(                                                                                                                                                                                                                                       
              ^^^^^^^^^^^^                                                                                                                                                                                                                                       
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 556, in __call__                                                                                                                           
    layer_outputs = block(                                                                                                                                                                                                                                       
                    ^^^^^^                                                                                                                                                                                                                                       
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 374, in __call__                                                                                                                           
    outputs = self.self_attn(                                                                                                                                                                                                                                    
              ^^^^^^^^^^^^^^^                                                                                                                                                                                                                                    
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/transformers/models/mistral/modeling_flax_mistral.py", line 241, in setup                                                                                                                              
    casual_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")                                                                                                                                                    
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                    
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/flax/linen/attention.py", line 810, in make_causal_mask                                                                                                                                                
    return make_attention_mask(                                                                                                                                                                                                                                  
           ^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                                                  
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/flax/linen/attention.py", line 786, in make_attention_mask                                           
    mask = jnp.expand_dims(mask, axis=-3)                                      
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                      
  File "/opt/conda/envs/zett/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 912, in expand_dims                                               
    return lax.expand_dims(a, axis)                                            
           ^^^^^^^^^^^^^^^^^^^^^^^^                                            
ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1073741824 bytes.                                                                       
BufferAssignment OOM Debugging.                                                
BufferAssignment stats:                                                        
             parameter allocation:    1.00GiB                                  
              constant allocation:         0B                                  
        maybe_live_out allocation:    1.00GiB                                  
     preallocated temp allocation:         0B                                  
                 total allocation:    2.00GiB                                  
              total fragmentation:         0B (0.00%)                          
Peak buffers:                                                                  
        Buffer 1:                                                                                
                Size: 1.00GiB                                                                    
                Entry Parameter Subshape: pred[1,32768,32768]                                    
                ==========================                                                       

        Buffer 2:                                                                                
                Size: 1.00GiB                                                                    
                XLA Label: fusion                                                                
                Shape: pred[1,1,32768,32768]                                                                           
                ==========================   
bminixhofer commented 3 weeks ago

Hi, it looks like it is trying to create a very large causal mask due to the high max_position_embeddings. You can try manually lowering the max_position_embeddings to the block_size, which should make it a lot easier on memory (and should be safe to do).
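For intuition: the 1.00GiB peak buffer in the dump above is exactly the boolean (`pred`) causal mask of shape `[1, 32768, 32768]` that `make_causal_mask` materializes, at one byte per element, so the cost grows quadratically in `max_position_embeddings`. A quick back-of-the-envelope check (plain arithmetic, not zett code; 128 stands in for a typical block_size):

```python
def causal_mask_bytes(max_position_embeddings: int) -> int:
    """Size of the dense boolean causal mask XLA allocates.

    make_causal_mask builds a (1, 1, seq_len, seq_len) mask, and a
    boolean ("pred") element occupies one byte.
    """
    return max_position_embeddings * max_position_embeddings

print(causal_mask_bytes(32768))  # 1073741824 bytes = 1.00 GiB, matching the dump
print(causal_mask_bytes(128))    # 16384 bytes = 16 KiB
```

So overriding `config.max_position_embeddings` down to the block_size before instantiating the Flax model shrinks each mask from 1 GiB to a few KiB.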

kdcyberdude commented 3 weeks ago

Hi @bminixhofer, do I need to update max_position_embeddings to 128 when initializing the roberta-base model in zett/model/init.py?

I tried without using a pretrained hypernet model as well; it still gives an OOM.

And what is the VRAM requirement for training this on a GPU?

Logs

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1718174206.615502  474103 hlo_rematerialization.cc:2946] Can't reduce memory use below -18.34GiB (-19688159409 bytes) by rematerialization; only reduced to 21.55GiB (23140745816 bytes), down from 21.55GiB (23140745816 bytes) originally
2024-06-12 12:06:47.545238: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to allocate request for 112.00MiB (117440512B) on device ordinal 0
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   16.50GiB
              constant allocation:        22B
        maybe_live_out allocation:   21.55GiB
     preallocated temp allocation:    13.4KiB
                 total allocation:   38.05GiB
              total fragmentation:    13.4KiB (0.00%)
Peak buffers:
        Buffer 1:
                Size: 1000.00MiB
                Entry Parameter Subshape: f32[32000,8192]
                ==========================

        Buffer 2:
                Size: 1000.00MiB
                XLA Label: fusion
                Shape: f32[32000,8192]
                ==========================

        Buffer 3:
                Size: 128.00MiB
                XLA Label: fusion
                Shape: f32[4096,8192]
                ==========================

        Buffer 4:
                Size: 128.00MiB
                XLA Label: fusion
                Shape: f32[8192,4096]
                ==========================

        Buffer 5:
                Size: 128.00MiB
                XLA Label: fusion
                Shape: f32[8192,4096]
                ==========================

        Buffer 6:
                Size: 128.00MiB
                XLA Label: fusion
                Shape: f32[8192,4096]
                ==========================

        Buffer 7:
                Size: 128.00MiB
                XLA Label: fusion
                Shape: f32[8192,4096]
                ==========================

        Buffer 8:
                Size: 128.00MiB
                XLA Label: fusion
                Shape: f32[8192,4096]
                ==========================

        Buffer 9:
                Size: 128.00MiB
                XLA Label: fusion
                Shape: f32[8192,4096]
                ==========================

        Buffer 10:
                Size: 128.00MiB
                Operator: op_name="jit(init_state)/jit(main)/broadcast_in_dim[shape=(8192, 4096) broadcast_dimensions=()]" source_file="/mnt/pi/proj/jun/zett/train.py" source_line=770
                XLA Label: fusion
                Shape: f32[8192,4096]
                ==========================

        Buffer 11:
                Size: 128.00MiB
                XLA Label: fusion
                Shape: f32[4096,8192]
                ==========================

        Buffer 12:
                Size: 128.00MiB
                XLA Label: fusion
                Shape: f32[4096,8192]
                ==========================

        Buffer 13:
                Size: 128.00MiB
                Operator: op_name="jit(init_state)/jit(main)/broadcast_in_dim[shape=(4096, 8192) broadcast_dimensions=()]" source_file="/mnt/pi/proj/jun/zett/train.py" source_line=770
                XLA Label: fusion
                Shape: f32[4096,8192]
                ==========================

        Buffer 14:
                Size: 128.00MiB
                XLA Label: fusion
                Shape: f32[8192,4096]
                ==========================

        Buffer 15:
                Size: 128.00MiB
                XLA Label: fusion
                Shape: f32[8192,4096]
                ==========================

Traceback (most recent call last):
  File "/home/kd/anaconda3/envs/zett/lib/python3.11/runpy.py", line 198, in _run_module_as_main
    return _run_code(code, main_globals, None,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kd/anaconda3/envs/zett/lib/python3.11/runpy.py", line 88, in _run_code
    exec(code, run_globals)
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/mnt/pi/proj/jun/zett/train.py", line 1629, in <module>
    main()
  File "/mnt/pi/proj/jun/zett/train.py", line 848, in main
    state = jax.jit(
            ^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 112.00MiB (117440512B) on device ordinal 0
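For a rough sense of the VRAM numbers in this dump: the buffer sizes follow directly from the float32 shapes XLA reports, at 4 bytes per element. A small sanity check (plain arithmetic, independent of zett):

```python
def f32_mib(*shape: int) -> float:
    """MiB occupied by a float32 buffer of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n * 4 / 2**20  # 4 bytes per float32 element

print(f32_mib(32000, 8192))  # 1000.0 -> the 1000.00MiB embedding-shaped buffers
print(f32_mib(8192, 4096))   # 128.0  -> each 128.00MiB MLP-shaped buffer
```

Note this only accounts for the buffers XLA lists at init; gradients, optimizer state, and activations add on top of it.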

PS: I am new to JAX.