kdcyberdude opened 4 weeks ago
Hi @bminixhofer, I am getting OOM with the logs below when training a Mistral multilingual hypernet. I have also tried this on two A100 (80 GB) GPUs; not sure what is wrong!

I have created a branch containing a script to reproduce this. You can run the `./install` script on any vast.ai instance: https://github.com/bminixhofer/zett/compare/main...kdcyberdude:zett:main

```
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1718174206.615502 474103 hlo_rematerialization.cc:2946] Can't reduce memory use below -18.34GiB (-19688159409 bytes) by rematerialization; only reduced to 21.55GiB (23140745816 bytes), down from 21.55GiB (23140745816 bytes) originally
2024-06-12 12:06:47.545238: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to allocate request for 112.00MiB (117440512B) on device ordinal 0
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 16.50GiB
constant allocation: 22B
maybe_live_out allocation: 21.55GiB
preallocated temp allocation: 13.4KiB
total allocation: 38.05GiB
total fragmentation: 13.4KiB (0.00%)
Peak buffers:
Buffer 1:
Size: 1000.00MiB
Entry Parameter Subshape: f32[32000,8192]
==========================
Buffer 2:
Size: 1000.00MiB
XLA Label: fusion
Shape: f32[32000,8192]
==========================
Buffer 3:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[4096,8192]
==========================
Buffer 4:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Buffer 5:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Buffer 6:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Buffer 7:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Buffer 8:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Buffer 9:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Buffer 10:
Size: 128.00MiB
Operator: op_name="jit(init_state)/jit(main)/broadcast_in_dim[shape=(8192, 4096) broadcast_dimensions=()]" source_file="/mnt/pi/proj/jun/zett/train.py" source_line=770
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Buffer 11:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[4096,8192]
==========================
Buffer 12:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[4096,8192]
==========================
Buffer 13:
Size: 128.00MiB
Operator: op_name="jit(init_state)/jit(main)/broadcast_in_dim[shape=(4096, 8192) broadcast_dimensions=()]" source_file="/mnt/pi/proj/jun/zett/train.py" source_line=770
XLA Label: fusion
Shape: f32[4096,8192]
==========================
Buffer 14:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Buffer 15:
Size: 128.00MiB
XLA Label: fusion
Shape: f32[8192,4096]
==========================
Traceback (most recent call last):
File "/home/kd/anaconda3/envs/zett/lib/python3.11/runpy.py", line 198, in _run_module_as_main
return _run_code(code, main_globals, None,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/kd/anaconda3/envs/zett/lib/python3.11/runpy.py", line 88, in _run_code
exec(code, run_globals)
File "/home/kd/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in
```
PS: I am new to JAX.
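For reference, a quick check of the numbers above, using only the shapes printed in the log: the peak buffer sizes match plain float32 tensors of the logged shapes.

```python
def f32_mib(*shape):
    """MiB occupied by a float32 tensor of the given shape (4 bytes per element)."""
    n = 1
    for d in shape:
        n *= d
    return n * 4 / 2**20

print(f32_mib(32000, 8192))  # 1000.0 -> Buffers 1-2, f32[32000,8192]
print(f32_mib(8192, 4096))   # 128.0  -> Buffers 3-15, f32[8192,4096] / f32[4096,8192]
```

The op names mention `jit(init_state)`, so the 38.05 GiB total allocation appears to be parameters plus their freshly initialized copies, before any training step runs.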
bminixhofer commented:

Hi, it looks like it is trying to create a very large causal mask due to the high `max_position_embeddings`. You can try manually lowering `max_position_embeddings` to the `block_size`, which should make it a lot easier on memory (and should be safe to do).
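A minimal sketch of that suggestion, assuming the target model is loaded via `transformers` (the checkpoint name and block size below are placeholders, not taken from the thread):

```python
from transformers import AutoConfig

# Placeholder checkpoint: Mistral-7B-v0.1 ships with max_position_embeddings=32768,
# and Flax causal-LM modules typically prebuild a (max_len, max_len) causal mask.
config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")

block_size = 128  # assumed training sequence length
config.max_position_embeddings = block_size  # mask is now built at 128 x 128

# The model would then be instantiated from the patched config, e.g.
# model = FlaxAutoModelForCausalLM.from_config(config)
```

At 32768 positions, a float32 mask alone would take 32768 * 32768 * 4 B = 4 GiB, so shrinking it to the block size is a substantial saving.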
kdcyberdude commented:

Hi @bminixhofer, do I need to update `max_position_embeddings` to 128 when initializing the `roberta-base` model in `zett/model/__init__.py`? I tried without the pretrained hypernet model as well; it still gives OOM. Also, what is the VRAM requirement for training this on a GPU?
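On the follow-up question, a hedged sketch of shrinking the `roberta-base` position table (where exactly zett should apply this is an assumption; the question points at `zett/model/__init__.py`). One RoBERTa-specific detail: position ids are offset by the padding index, which is why the stock config says 514 rather than 512.

```python
from transformers import AutoConfig, FlaxAutoModelForMaskedLM

block_size = 128  # the value asked about above

config = AutoConfig.from_pretrained("roberta-base")
print(config.max_position_embeddings)  # 514 = 512 tokens + 2 (padding offset)

# Keep the same +2 margin when shrinking to the training block size.
config.max_position_embeddings = block_size + 2

# Randomly initialized Flax weights at the reduced size (no pretrained
# position table is loaded here).
model = FlaxAutoModelForMaskedLM.from_config(config)
```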