young-geng / EasyLM

Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, finetuning, evaluating, and serving LLMs in JAX/Flax.
Apache License 2.0

OOM trying to pretrain llama 7b on v4-256 #98

Open redbrain opened 9 months ago

redbrain commented 9 months ago
Command

```sh
python -m EasyLM.models.llama.llama_train \
    --mesh_dim='-1,32,1' \
    --dtype='fp32' \
    --total_steps=250000 \
    --log_freq=50 \
    --save_model_freq=0 \
    --save_milestone_freq=2500 \
    --load_llama_config='7b' \
    --update_llama_config='' \
    --load_dataset_state='' \
    --load_checkpoint='' \
    --tokenizer.vocab_file='gs://.../tokenizer.model' \
    --optimizer.type='adamw' \
    --optimizer.adamw_optimizer.weight_decay=0.1 \
    --optimizer.adamw_optimizer.lr=3e-4 \
    --optimizer.adamw_optimizer.end_lr=3e-5 \
    --optimizer.adamw_optimizer.lr_warmup_steps=2000 \
    --optimizer.adamw_optimizer.lr_decay_steps=250000 \
    --train_dataset.type='json' \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.json_dataset.path='gs://.../dataset.jsonl' \
    --train_dataset.json_dataset.seq_length=2048 \
    --train_dataset.json_dataset.batch_size=2048 \
    --train_dataset.json_dataset.tokenizer_processes=16 \
    --checkpointer.save_optimizer_state=True \
    --logger.online=True \
    --logger.prefix='devingulliver' \
    --logger.project="sl_llama_7b" \
    --logger.output_dir="gs://.../output/" \
    --logger.wandb_dir="$HOME/experiment_output/sl_llama_7b"
```
Log

```
I1008 02:58:09.536914 139894565414912 mesh_utils.py:282] _create_device_mesh_for_nd_torus assignment: [(1,), (0, 2), ()]
  0% 0/250000 [02:56<?, ?it/s]
Traceback (most recent call last):
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 267, in <module>
    mlxu.run(main)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 235, in main
    train_state, sharded_rng, metrics = sharded_train_step(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1123, in _pjit_call_impl_python
    always_lower=False, lowering_platform=None).compile()
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 163.10G of 30.75G hbm. Exceeded hbm capacity by 132.35G.

Total hbm usage >= 164.35G:
    reserved          1.25G
    program         163.10G
    arguments            0B

Output size 0B; shares 0B with arguments.

Program hbm requirement 163.10G:
    global           20.85M
    scoped            1.19M
    HLO temp        163.08G (100.0% utilization: Unpadded (157.97G) Padded (157.97G), 3.1% fragmentation (5.11G))

  Largest program allocations in hbm:

  1. Size: 8.00G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/22/attention/sub" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/attention.py" source_line=108
     Shape: f32[16,32,2048,2048]{2,3,1,0:T(8,128)}
     Unpadded size: 8.00G
     XLA label: fusion.8979.remat4 = fusion(fusion.112.remat), kind=kOutput, calls=fused_computation.6065.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  2. Size: 8.00G
     Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/22/attention/...hqk,...khd->...qhd/dot_general[dimension_numbers=(((2,), (3,)), ((0, 1), (0, 2))) precision=None preferred_element_type=None]" source_file="/home/redbrain/EasyLM/EasyLM/models/llama/llama_model.py" source_line=569
     Shape: f32[16,32,2048,2048]{2,3,1,0:T(8,128)}
     Unpadded size: 8.00G
     XLA label: fusion.1599.remat = fusion(bitcast.2905, reshape.13130), kind=kOutput, calls=fused_computation.1400.clone
     Allocation type: HLO temp
     ==========================
  3. Size: 1.34G
     Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/19/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206
     Shape: f32[16,2048,11008]{2,1,0:T(8,128)}
     Unpadded size: 1.34G
     XLA label: fusion.1888.remat4 = fusion(fusion.942.remat7.1.remat, copy-done.2890, copy-done.2506, all-gather.568.remat2), kind=kOutput, calls=fused_computation.1611.clone.clone.clone.clone
     Allocation type: HLO temp
     ==========================
  [allocations 4-20 trimmed: each is a 1.34G HLO temp of shape f32[16,2048,11008]{2,1,0:T(8,128)} from a feed_forward w1/w3 dot_general (flax/linen/linear.py:206) across transformer layers 6, 15-22, and 29]

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 267, in <module>
    mlxu.run(main)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 235, in main
    train_state, sharded_rng, metrics = sharded_train_step(
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 163.10G of 30.75G hbm. Exceeded hbm capacity by 132.35G.
  [same hbm usage breakdown and "Largest program allocations" list as above; the captured output then repeats the whole unfiltered/filtered exception pair a second time]
```
young-geng commented 9 months ago

This is expected. A v4-256 actually has only 128 TPU v4 chips (a quirk of the naming convention: before v4, the two TensorCores on a chip were counted as separate devices), so our OpenLLaMA 7B configuration actually uses a v4-512. If you want to train on a v4-256, consider using a batch size of 1024 and `--mesh_dim='-1,64,1'`.
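
For clarity, here's a rough sketch of how that mesh spec resolves on a v4-256 (a toy helper of my own for illustration, not EasyLM's actual parsing code):

```python
import numpy as np

# Toy helper (not EasyLM code): resolve a mesh_dim string like '-1,64,1'
# against the number of available devices; the -1 is filled in the same
# way as in a reshape.
def resolve_mesh_dims(spec: str, n_devices: int) -> tuple:
    dims = [int(d) for d in spec.split(',')]
    known = int(np.prod([d for d in dims if d > 0]))
    return tuple(n_devices // known if d == -1 else d for d in dims)

# A v4-256 exposes 128 JAX devices (one per chip), so:
print(resolve_mesh_dims('-1,64,1', 128))  # -> (2, 64, 1)
```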

redbrain commented 9 months ago

This results in `NotImplementedError: Failed to find assignment for logical_axis_index 1 of size 64 with remaining assignable mesh [4, 4, 8]`. Any clue what went wrong?

redbrain commented 8 months ago

It appears that since a v4-256 has half the chips of a v4-512, the appropriate mesh topology would be `-1,32,1`. But running with that mesh, at batch sizes of 1024 and even 512, still produces OOM errors. Any advice on how to fix this?
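
For what it's worth, here's a quick back-of-the-envelope check against the allocation sizes in the log above (my own arithmetic as a sketch, not anything produced by EasyLM):

```python
# The biggest recurring allocations in the batch-size-2048 log are the fp32
# feed-forward activations with shape f32[16, 2048, 11008]:
# per-shard batch 16, sequence length 2048, MLP width 11008.
bytes_per_f32 = 4
ffn_activation = 16 * 2048 * 11008 * bytes_per_f32
print(ffn_activation / 2**30)  # ~1.34 GiB, matching the "Size: 1.34G" entries
```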

young-geng commented 8 months ago

Oh, this is a known issue: by default, JAX does not want to split a physical mesh axis across multiple logical axes. However, we can force it to do so by specifying `--mesh_dim='!-1,64,1'`.
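
Roughly speaking, forcing the split amounts to something like the following (a simplified sketch, not EasyLM's actual implementation; the axis names are placeholders):

```python
import numpy as np
import jax
from jax.sharding import Mesh

# Instead of asking mesh_utils to map logical axes onto the physical
# [4, 4, 8] torus, take the flat device list and reshape it directly.
# Requires a host attached to 128 devices (v4-256).
devices = np.array(jax.devices()).reshape(2, 64, 1)
mesh = Mesh(devices, axis_names=('dp', 'fsdp', 'mp'))
print(mesh.shape)  # {'dp': 2, 'fsdp': 64, 'mp': 1}
```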

redbrain commented 8 months ago

Still not working, even with the `mesh_dim` and `batch_size` values you suggested.

Full command ```sh export LIBTPU_INIT_ARGS='--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE' python -m EasyLM.models.llama.llama_train \ --mesh_dim='!-1,64,1' \ --dtype='fp32' \ --total_steps=250000 \ --log_freq=50 \ --save_model_freq=0 \ --save_milestone_freq=2500 \ --load_llama_config='7b' \ --update_llama_config='' \ --load_dataset_state='' \ --load_checkpoint='' \ --tokenizer.vocab_file='gs://.../tokenizer.model' \ --optimizer.type='adamw' \ --optimizer.adamw_optimizer.weight_decay=0.1 \ --optimizer.adamw_optimizer.lr=3e-4 \ --optimizer.adamw_optimizer.end_lr=3e-5 \ --optimizer.adamw_optimizer.lr_warmup_steps=2000 \ --optimizer.adamw_optimizer.lr_decay_steps=250000 \ --train_dataset.type='json' \ --train_dataset.text_processor.fields='text' \ --train_dataset.json_dataset.path='gs://.../slimpajama.jsonl' \ --train_dataset.json_dataset.seq_length=2048 \ --train_dataset.json_dataset.batch_size=1024 \ --train_dataset.json_dataset.tokenizer_processes=16 \ --checkpointer.save_optimizer_state=True \ --logger.online=True \ --logger.prefix='devingulliver' \ --logger.project="slender_llama_7b" \ --logger.output_dir="gs://.../output/" \ --logger.wandb_dir="$HOME/experiment_output/slender_llama_7b" \ |& tee $HOME/output.txt ```
Full output log ``` 0% 0/250000 [02:22 mlxu.run(main) File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run _run_main(main, args) File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main sys.exit(main(argv)) File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 235, in main train_state, sharded_rng, metrics = sharded_train_step( File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback return fun(*args, **kwargs) File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper( File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper out_flat = pjit_p.bind(*args_flat, **params) File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 2677, in bind return self.bind_with_trace(top_trace, args, params) File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace out = trace.process_primitive(self, map(trace.full_raise, args), params) File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 815, in process_primitive return primitive.impl(*tracers, **params) File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss out_flat, compiled = _pjit_call_impl_python( File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1123, in _pjit_call_impl_python always_lower=False, lowering_platform=None).compile() File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile executable = UnloadedMeshExecutable.from_hlo( File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo xla_executable, compile_options = _cached_compilation( File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation xla_executable = dispatch.compile_or_get_cached( File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached return backend_compile(backend, computation, compile_options, File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper return func(*args, **kwargs) File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 465, in backend_compile return backend.compile(built_c, compile_options=options) jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 87.08G of 30.75G hbm. Exceeded hbm capacity by 56.33G. Total hbm usage >= 88.33G: reserved 1.25G program 87.08G arguments 0B Output size 0B; shares 0B with arguments. Program hbm requirement 87.08G: global 13.02M scoped 529.0K HLO temp 87.06G (100.0% utilization: Unpadded (86.24G) Padded (86.24G), 0.9% fragmentation (841.66M)) Largest program allocations in hbm: 1. 
Size: 4.00G Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/4/attention/sub" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/attention.py" source_line=108 Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)} Unpadded size: 4.00G XLA label: fusion.8917.remat7 = fusion(fusion.130.remat6), kind=kOutput, calls=fused_computation.6051.clone.clone.clone.clone.clone.clone.clone Allocation type: HLO temp ========================== 2. Size: 4.00G Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/0/attention/...hqk,...khd->...qhd/dot_general[dimension_numbers=(((2,), (3,)), ((0, 1), (0, 2))) precision=None preferred_element_type=None]" source_file="/home/redbrain/EasyLM/EasyLM/models/llama/llama_model.py" source_line=569 Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)} Unpadded size: 4.00G XLA label: fusion.136 = fusion(get-tuple-element.5938, copy-done.1220, bitcast.6320, copy-done.206), kind=kOutput, calls=fused_computation.136 Allocation type: HLO temp ========================== 3. Size: 4.00G Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/0/attention/sub" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/attention.py" source_line=108 Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)} Unpadded size: 4.00G XLA label: fusion.8913.remat5 = fusion(fusion.134.remat2), kind=kOutput, calls=fused_computation.6047.clone.clone.clone.clone.clone Allocation type: HLO temp ========================== 4. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/30/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1977.remat4 = fusion(fusion.926.remat5.1.remat, all-gather.623.remat2, param.536, fusion.6437), kind=kOutput, calls=fused_computation.1655.clone.clone.clone.clone Allocation type: HLO temp ========================== 5. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/31/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1985.remat4 = fusion(fusion.924.remat5.1.remat, all-gather.630.remat2, param.545, fusion.6439), kind=kOutput, calls=fused_computation.1659.clone.clone.clone.clone Allocation type: HLO temp ========================== 6. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/0/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2799.remat2 = fusion(fusion.1057, all-gather.748.remat2), kind=kOutput, calls=fused_computation.2158.clone.clone Allocation type: HLO temp ========================== 7. 
Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/30/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2559.remat5 = fusion(fusion.1207, copy-done.454), kind=kOutput, calls=fused_computation.1918.clone.clone.clone.clone.clone Allocation type: HLO temp ========================== 8. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/1/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1745.remat3 = fusion(fusion.984.remat3, copy-done.464, copy-done.2938, copy-done.2521), kind=kOutput, calls=fused_computation.1539.clone.clone.clone Allocation type: HLO temp ========================== 9. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/1/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2791.remat2 = fusion(fusion.1062, all-gather.741.remat2), kind=kOutput, calls=fused_computation.2150.clone.clone Allocation type: HLO temp ========================== 10. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/29/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2567.remat5 = fusion(fusion.1202, copy-done.452), kind=kOutput, calls=fused_computation.1926.clone.clone.clone.clone.clone Allocation type: HLO temp ========================== 11. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/2/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1753.remat3 = fusion(fusion.982.remat3, copy-done.462, copy-done.2888, copy-done.2518), kind=kOutput, calls=fused_computation.1543.clone.clone.clone Allocation type: HLO temp ========================== 12. 
Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/2/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2783.remat2 = fusion(fusion.1067, all-gather.734.remat2), kind=kOutput, calls=fused_computation.2142.clone.clone Allocation type: HLO temp ========================== 13. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/3/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1759.remat3 = fusion(fusion.980.remat3, all-gather.729.remat2, copy-done.2827, copy-done.2527), kind=kOutput, calls=fused_computation.1546.clone.clone.clone Allocation type: HLO temp ========================== 14. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/6/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1783.remat4 = fusion(fusion.974.remat7.1.remat, all-gather.456.remat2, copy-done.2801, copy-done.2540), kind=kOutput, calls=fused_computation.1558.clone.clone.clone.clone Allocation type: HLO temp ========================== 15. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/6/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1785.remat4 = fusion(fusion.974.remat5.1.remat, all-gather.455.remat2, param.320, fusion.6389), kind=kOutput, calls=fused_computation.1559.clone.clone.clone.clone Allocation type: HLO temp ========================== 16. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/3/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1761.remat3 = fusion(fusion.980.remat3, copy-done.460, copy-done.2826, copy-done.2527), kind=kOutput, calls=fused_computation.1547.clone.clone.clone Allocation type: HLO temp ========================== 17. 
Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/8/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1799.remat4 = fusion(fusion.970.remat7.1.remat, all-gather.470.remat2, copy-done.2793, copy-done.2548), kind=kOutput, calls=fused_computation.1566.clone.clone.clone.clone Allocation type: HLO temp ========================== 18. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/8/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1801.remat4 = fusion(fusion.970.remat5.1.remat, all-gather.469.remat2, param.338, fusion.6393), kind=kOutput, calls=fused_computation.1567.clone.clone.clone.clone Allocation type: HLO temp ========================== 19. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/9/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1809.remat4 = fusion(fusion.968.remat5.1.remat, all-gather.476.remat2, param.347, fusion.6395), kind=kOutput, calls=fused_computation.1571.clone.clone.clone.clone Allocation type: HLO temp ========================== 20. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/3/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2775.remat2 = fusion(fusion.1072, all-gather.727.remat2), kind=kOutput, calls=fused_computation.2134.clone.clone Allocation type: HLO temp ========================== The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified. -------------------- The above exception was the direct cause of the following exception: Traceback (most recent call last): File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main return _run_code(code, main_globals, None, File "/usr/lib/python3.10/runpy.py", line 86, in _run_code exec(code, run_globals) File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 267, in mlxu.run(main) File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run _run_main(main, args) File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main sys.exit(main(argv)) File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 235, in main train_state, sharded_rng, metrics = sharded_train_step( jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 87.08G of 30.75G hbm. 
Exceeded hbm capacity by 56.33G. Total hbm usage >= 88.33G: reserved 1.25G program 87.08G arguments 0B Output size 0B; shares 0B with arguments. Program hbm requirement 87.08G: global 13.02M scoped 529.0K HLO temp 87.06G (100.0% utilization: Unpadded (86.24G) Padded (86.24G), 0.9% fragmentation (841.66M)) Largest program allocations in hbm: 1. Size: 4.00G Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/4/attention/sub" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/attention.py" source_line=108 Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)} Unpadded size: 4.00G XLA label: fusion.8917.remat7 = fusion(fusion.130.remat6), kind=kOutput, calls=fused_computation.6051.clone.clone.clone.clone.clone.clone.clone Allocation type: HLO temp ========================== 2. Size: 4.00G Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/0/attention/...hqk,...khd->...qhd/dot_general[dimension_numbers=(((2,), (3,)), ((0, 1), (0, 2))) precision=None preferred_element_type=None]" source_file="/home/redbrain/EasyLM/EasyLM/models/llama/llama_model.py" source_line=569 Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)} Unpadded size: 4.00G XLA label: fusion.136 = fusion(get-tuple-element.5938, copy-done.1220, bitcast.6320, copy-done.206), kind=kOutput, calls=fused_computation.136 Allocation type: HLO temp ========================== 3. Size: 4.00G Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/0/attention/sub" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/attention.py" source_line=108 Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)} Unpadded size: 4.00G XLA label: fusion.8913.remat5 = fusion(fusion.134.remat2), kind=kOutput, calls=fused_computation.6047.clone.clone.clone.clone.clone Allocation type: HLO temp ========================== 4. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/30/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1977.remat4 = fusion(fusion.926.remat5.1.remat, all-gather.623.remat2, param.536, fusion.6437), kind=kOutput, calls=fused_computation.1655.clone.clone.clone.clone Allocation type: HLO temp ========================== 5. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/31/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1985.remat4 = fusion(fusion.924.remat5.1.remat, all-gather.630.remat2, param.545, fusion.6439), kind=kOutput, calls=fused_computation.1659.clone.clone.clone.clone Allocation type: HLO temp ========================== 6. 
Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/0/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2799.remat2 = fusion(fusion.1057, all-gather.748.remat2), kind=kOutput, calls=fused_computation.2158.clone.clone Allocation type: HLO temp ========================== 7. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/30/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2559.remat5 = fusion(fusion.1207, copy-done.454), kind=kOutput, calls=fused_computation.1918.clone.clone.clone.clone.clone Allocation type: HLO temp ========================== 8. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/1/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1745.remat3 = fusion(fusion.984.remat3, copy-done.464, copy-done.2938, copy-done.2521), kind=kOutput, calls=fused_computation.1539.clone.clone.clone Allocation type: HLO temp ========================== 9. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/1/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2791.remat2 = fusion(fusion.1062, all-gather.741.remat2), kind=kOutput, calls=fused_computation.2150.clone.clone Allocation type: HLO temp ========================== 10. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/29/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2567.remat5 = fusion(fusion.1202, copy-done.452), kind=kOutput, calls=fused_computation.1926.clone.clone.clone.clone.clone Allocation type: HLO temp ========================== 11. 
Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/2/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1753.remat3 = fusion(fusion.982.remat3, copy-done.462, copy-done.2888, copy-done.2518), kind=kOutput, calls=fused_computation.1543.clone.clone.clone Allocation type: HLO temp ========================== 12. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/2/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2783.remat2 = fusion(fusion.1067, all-gather.734.remat2), kind=kOutput, calls=fused_computation.2142.clone.clone Allocation type: HLO temp ========================== 13. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/3/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1759.remat3 = fusion(fusion.980.remat3, all-gather.729.remat2, copy-done.2827, copy-done.2527), kind=kOutput, calls=fused_computation.1546.clone.clone.clone Allocation type: HLO temp ========================== 14. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/6/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1783.remat4 = fusion(fusion.974.remat7.1.remat, all-gather.456.remat2, copy-done.2801, copy-done.2540), kind=kOutput, calls=fused_computation.1558.clone.clone.clone.clone Allocation type: HLO temp ========================== 15. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/6/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1785.remat4 = fusion(fusion.974.remat5.1.remat, all-gather.455.remat2, param.320, fusion.6389), kind=kOutput, calls=fused_computation.1559.clone.clone.clone.clone Allocation type: HLO temp ========================== 16. 
Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/3/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1761.remat3 = fusion(fusion.980.remat3, copy-done.460, copy-done.2826, copy-done.2527), kind=kOutput, calls=fused_computation.1547.clone.clone.clone Allocation type: HLO temp ========================== 17. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/8/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1799.remat4 = fusion(fusion.970.remat7.1.remat, all-gather.470.remat2, copy-done.2793, copy-done.2548), kind=kOutput, calls=fused_computation.1566.clone.clone.clone.clone Allocation type: HLO temp ========================== 18. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/8/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1801.remat4 = fusion(fusion.970.remat5.1.remat, all-gather.469.remat2, param.338, fusion.6393), kind=kOutput, calls=fused_computation.1567.clone.clone.clone.clone Allocation type: HLO temp ========================== 19. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/9/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1809.remat4 = fusion(fusion.968.remat5.1.remat, all-gather.476.remat2, param.347, fusion.6395), kind=kOutput, calls=fused_computation.1571.clone.clone.clone.clone Allocation type: HLO temp ========================== 20. 
Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/3/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2775.remat2 = fusion(fusion.1072, all-gather.727.remat2), kind=kOutput, calls=fused_computation.2134.clone.clone Allocation type: HLO temp ========================== Traceback (most recent call last): File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 267, in mlxu.run(main) File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run _run_main(main, args) File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main sys.exit(main(argv)) File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 235, in main train_state, sharded_rng, metrics = sharded_train_step( File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback return fun(*args, **kwargs) File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper( File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper out_flat = pjit_p.bind(*args_flat, **params) File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 2677, in bind return self.bind_with_trace(top_trace, args, params) File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace out = trace.process_primitive(self, map(trace.full_raise, args), params) File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/core.py", line 815, in process_primitive return primitive.impl(*tracers, **params) File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss out_flat, compiled = _pjit_call_impl_python( File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1123, in _pjit_call_impl_python always_lower=False, lowering_platform=None).compile() File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile executable = UnloadedMeshExecutable.from_hlo( File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo xla_executable, compile_options = _cached_compilation( File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation xla_executable = dispatch.compile_or_get_cached( File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached return backend_compile(backend, computation, compile_options, File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper return func(*args, **kwargs) File "/home/redbrain/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 465, in backend_compile return backend.compile(built_c, compile_options=options) jax._src.traceback_util.UnfilteredStackTrace: 
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 87.08G of 30.75G hbm. Exceeded hbm capacity by 56.33G. Total hbm usage >= 88.33G: reserved 1.25G program 87.08G arguments 0B Output size 0B; shares 0B with arguments. Program hbm requirement 87.08G: global 13.02M scoped 529.0K HLO temp 87.06G (100.0% utilization: Unpadded (86.24G) Padded (86.24G), 0.9% fragmentation (841.66M)) Largest program allocations in hbm: 1. Size: 4.00G Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/4/attention/sub" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/attention.py" source_line=108 Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)} Unpadded size: 4.00G XLA label: fusion.8917.remat7 = fusion(fusion.130.remat6), kind=kOutput, calls=fused_computation.6051.clone.clone.clone.clone.clone.clone.clone Allocation type: HLO temp ========================== 2. Size: 4.00G Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/0/attention/...hqk,...khd->...qhd/dot_general[dimension_numbers=(((2,), (3,)), ((0, 1), (0, 2))) precision=None preferred_element_type=None]" source_file="/home/redbrain/EasyLM/EasyLM/models/llama/llama_model.py" source_line=569 Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)} Unpadded size: 4.00G XLA label: fusion.136 = fusion(get-tuple-element.5938, copy-done.1220, bitcast.6320, copy-done.206), kind=kOutput, calls=fused_computation.136 Allocation type: HLO temp ========================== 3. Size: 4.00G Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/0/attention/sub" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/attention.py" source_line=108 Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)} Unpadded size: 4.00G XLA label: fusion.8913.remat5 = fusion(fusion.134.remat2), kind=kOutput, calls=fused_computation.6047.clone.clone.clone.clone.clone Allocation type: HLO temp ========================== 4. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/30/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1977.remat4 = fusion(fusion.926.remat5.1.remat, all-gather.623.remat2, param.536, fusion.6437), kind=kOutput, calls=fused_computation.1655.clone.clone.clone.clone Allocation type: HLO temp ========================== 5. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/31/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1985.remat4 = fusion(fusion.924.remat5.1.remat, all-gather.630.remat2, param.545, fusion.6439), kind=kOutput, calls=fused_computation.1659.clone.clone.clone.clone Allocation type: HLO temp ========================== 6. 
Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/0/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2799.remat2 = fusion(fusion.1057, all-gather.748.remat2), kind=kOutput, calls=fused_computation.2158.clone.clone Allocation type: HLO temp ========================== 7. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/30/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2559.remat5 = fusion(fusion.1207, copy-done.454), kind=kOutput, calls=fused_computation.1918.clone.clone.clone.clone.clone Allocation type: HLO temp ========================== 8. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/1/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1745.remat3 = fusion(fusion.984.remat3, copy-done.464, copy-done.2938, copy-done.2521), kind=kOutput, calls=fused_computation.1539.clone.clone.clone Allocation type: HLO temp ========================== 9. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/1/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2791.remat2 = fusion(fusion.1062, all-gather.741.remat2), kind=kOutput, calls=fused_computation.2150.clone.clone Allocation type: HLO temp ========================== 10. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/29/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2567.remat5 = fusion(fusion.1202, copy-done.452), kind=kOutput, calls=fused_computation.1926.clone.clone.clone.clone.clone Allocation type: HLO temp ========================== 11. 
Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/2/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1753.remat3 = fusion(fusion.982.remat3, copy-done.462, copy-done.2888, copy-done.2518), kind=kOutput, calls=fused_computation.1543.clone.clone.clone Allocation type: HLO temp ========================== 12. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/2/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2783.remat2 = fusion(fusion.1067, all-gather.734.remat2), kind=kOutput, calls=fused_computation.2142.clone.clone Allocation type: HLO temp ========================== 13. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/3/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1759.remat3 = fusion(fusion.980.remat3, all-gather.729.remat2, copy-done.2827, copy-done.2527), kind=kOutput, calls=fused_computation.1546.clone.clone.clone Allocation type: HLO temp ========================== 14. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/6/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1783.remat4 = fusion(fusion.974.remat7.1.remat, all-gather.456.remat2, copy-done.2801, copy-done.2540), kind=kOutput, calls=fused_computation.1558.clone.clone.clone.clone Allocation type: HLO temp ========================== 15. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/6/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1785.remat4 = fusion(fusion.974.remat5.1.remat, all-gather.455.remat2, param.320, fusion.6389), kind=kOutput, calls=fused_computation.1559.clone.clone.clone.clone Allocation type: HLO temp ========================== 16. 
Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/3/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1761.remat3 = fusion(fusion.980.remat3, copy-done.460, copy-done.2826, copy-done.2527), kind=kOutput, calls=fused_computation.1547.clone.clone.clone Allocation type: HLO temp ========================== 17. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/8/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1799.remat4 = fusion(fusion.970.remat7.1.remat, all-gather.470.remat2, copy-done.2793, copy-done.2548), kind=kOutput, calls=fused_computation.1566.clone.clone.clone.clone Allocation type: HLO temp ========================== 18. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/8/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1801.remat4 = fusion(fusion.970.remat5.1.remat, all-gather.469.remat2, param.338, fusion.6393), kind=kOutput, calls=fused_computation.1567.clone.clone.clone.clone Allocation type: HLO temp ========================== 19. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/9/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1809.remat4 = fusion(fusion.968.remat5.1.remat, all-gather.476.remat2, param.347, fusion.6395), kind=kOutput, calls=fused_computation.1571.clone.clone.clone.clone Allocation type: HLO temp ========================== 20. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/3/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2775.remat2 = fusion(fusion.1072, all-gather.727.remat2), kind=kOutput, calls=fused_computation.2134.clone.clone Allocation type: HLO temp ========================== The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified. 
-------------------- The above exception was the direct cause of the following exception: Traceback (most recent call last): File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main return _run_code(code, main_globals, None, File "/usr/lib/python3.10/runpy.py", line 86, in _run_code exec(code, run_globals) File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 267, in mlxu.run(main) File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run _run_main(main, args) File "/home/redbrain/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main sys.exit(main(argv)) File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 235, in main train_state, sharded_rng, metrics = sharded_train_step( jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 87.08G of 30.75G hbm. Exceeded hbm capacity by 56.33G. Total hbm usage >= 88.33G: reserved 1.25G program 87.08G arguments 0B Output size 0B; shares 0B with arguments. Program hbm requirement 87.08G: global 13.02M scoped 529.0K HLO temp 87.06G (100.0% utilization: Unpadded (86.24G) Padded (86.24G), 0.9% fragmentation (841.66M)) Largest program allocations in hbm: 1. Size: 4.00G Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/4/attention/sub" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/attention.py" source_line=108 Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)} Unpadded size: 4.00G XLA label: fusion.8917.remat7 = fusion(fusion.130.remat6), kind=kOutput, calls=fused_computation.6051.clone.clone.clone.clone.clone.clone.clone Allocation type: HLO temp ========================== 2. Size: 4.00G Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/0/attention/...hqk,...khd->...qhd/dot_general[dimension_numbers=(((2,), (3,)), ((0, 1), (0, 2))) precision=None preferred_element_type=None]" source_file="/home/redbrain/EasyLM/EasyLM/models/llama/llama_model.py" source_line=569 Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)} Unpadded size: 4.00G XLA label: fusion.136 = fusion(get-tuple-element.5938, copy-done.1220, bitcast.6320, copy-done.206), kind=kOutput, calls=fused_computation.136 Allocation type: HLO temp ========================== 3. Size: 4.00G Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/0/attention/sub" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/attention.py" source_line=108 Shape: f32[8,32,2048,2048]{2,3,1,0:T(8,128)} Unpadded size: 4.00G XLA label: fusion.8913.remat5 = fusion(fusion.134.remat2), kind=kOutput, calls=fused_computation.6047.clone.clone.clone.clone.clone Allocation type: HLO temp ========================== 4. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/30/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1977.remat4 = fusion(fusion.926.remat5.1.remat, all-gather.623.remat2, param.536, fusion.6437), kind=kOutput, calls=fused_computation.1655.clone.clone.clone.clone Allocation type: HLO temp ========================== 5. 
Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/31/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1985.remat4 = fusion(fusion.924.remat5.1.remat, all-gather.630.remat2, param.545, fusion.6439), kind=kOutput, calls=fused_computation.1659.clone.clone.clone.clone Allocation type: HLO temp ========================== 6. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/0/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2799.remat2 = fusion(fusion.1057, all-gather.748.remat2), kind=kOutput, calls=fused_computation.2158.clone.clone Allocation type: HLO temp ========================== 7. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/30/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2559.remat5 = fusion(fusion.1207, copy-done.454), kind=kOutput, calls=fused_computation.1918.clone.clone.clone.clone.clone Allocation type: HLO temp ========================== 8. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/1/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1745.remat3 = fusion(fusion.984.remat3, copy-done.464, copy-done.2938, copy-done.2521), kind=kOutput, calls=fused_computation.1539.clone.clone.clone Allocation type: HLO temp ========================== 9. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/1/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2791.remat2 = fusion(fusion.1062, all-gather.741.remat2), kind=kOutput, calls=fused_computation.2150.clone.clone Allocation type: HLO temp ========================== 10. 
Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/29/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2567.remat5 = fusion(fusion.1202, copy-done.452), kind=kOutput, calls=fused_computation.1926.clone.clone.clone.clone.clone Allocation type: HLO temp ========================== 11. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/2/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1753.remat3 = fusion(fusion.982.remat3, copy-done.462, copy-done.2888, copy-done.2518), kind=kOutput, calls=fused_computation.1543.clone.clone.clone Allocation type: HLO temp ========================== 12. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/2/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2783.remat2 = fusion(fusion.1067, all-gather.734.remat2), kind=kOutput, calls=fused_computation.2142.clone.clone Allocation type: HLO temp ========================== 13. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/3/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1759.remat3 = fusion(fusion.980.remat3, all-gather.729.remat2, copy-done.2827, copy-done.2527), kind=kOutput, calls=fused_computation.1546.clone.clone.clone Allocation type: HLO temp ========================== 14. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/6/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1783.remat4 = fusion(fusion.974.remat7.1.remat, all-gather.456.remat2, copy-done.2801, copy-done.2540), kind=kOutput, calls=fused_computation.1558.clone.clone.clone.clone Allocation type: HLO temp ========================== 15. 
Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/6/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1785.remat4 = fusion(fusion.974.remat5.1.remat, all-gather.455.remat2, param.320, fusion.6389), kind=kOutput, calls=fused_computation.1559.clone.clone.clone.clone Allocation type: HLO temp ========================== 16. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/3/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1761.remat3 = fusion(fusion.980.remat3, copy-done.460, copy-done.2826, copy-done.2527), kind=kOutput, calls=fused_computation.1547.clone.clone.clone Allocation type: HLO temp ========================== 17. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/8/feed_forward/w3/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1799.remat4 = fusion(fusion.970.remat7.1.remat, all-gather.470.remat2, copy-done.2793, copy-done.2548), kind=kOutput, calls=fused_computation.1566.clone.clone.clone.clone Allocation type: HLO temp ========================== 18. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/8/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1801.remat4 = fusion(fusion.970.remat5.1.remat, all-gather.469.remat2, param.338, fusion.6393), kind=kOutput, calls=fused_computation.1567.clone.clone.clone.clone Allocation type: HLO temp ========================== 19. Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/jvp(FlaxLLaMAForCausalLMModule)/transformer/h/9/feed_forward/w1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.1809.remat4 = fusion(fusion.968.remat5.1.remat, all-gather.476.remat2, param.347, fusion.6395), kind=kOutput, calls=fused_computation.1571.clone.clone.clone.clone Allocation type: HLO temp ========================== 20. 
Size: 688.00M Operator: op_name="pjit(train_step)/jit(main)/transpose(jvp(FlaxLLaMAForCausalLMModule))/transformer/h/3/feed_forward/w2/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/redbrain/.local/lib/python3.10/site-packages/flax/linen/linear.py" source_line=206 Shape: f32[8,2048,11008]{2,1,0:T(8,128)} Unpadded size: 688.00M XLA label: fusion.2775.remat2 = fusion(fusion.1072, all-gather.727.remat2), kind=kOutput, calls=fused_computation.2134.clone.clone Allocation type: HLO temp ========================== ```
young-geng commented 8 months ago

This is quite strange. Maybe XLA is not being smart about allocating memory here. In that case I'd recommend tweaking the batch size and mesh size. For example, try an even smaller batch size of 512, or use '-1,128,1' as the mesh dim.
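A minimal sketch of what that suggestion could look like against the training command in this issue (only the two flags being changed are shown; the specific values and the mapping onto the dataset batch-size flag are assumptions, and all other flags are assumed to stay the same):

```sh
# Sketch only: reduce the global batch and shard the model across more devices.
# All remaining flags are assumed unchanged from the original command.
python -m EasyLM.models.llama.llama_train \
    --mesh_dim='-1,128,1' \
    --train_dataset.json_dataset.batch_size=512
    # ... keep the rest of the flags from the original command
```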

0x7o commented 7 hours ago

I was able to run 7B model training on a TPU v4-256 with mesh_dim = '-1,16,4' and batch_size = 64, at about 115,000 tokens per second.
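A sketch of that reported configuration, assuming it maps onto the same llama_train flags used earlier in this issue (mesh_dim plus the JSON dataset batch size); only the changed flags are shown:

```sh
# Sketch of the setup 0x7o reports working on v4-256. The mapping of
# batch_size=64 onto the dataset flag is an assumption; other flags unchanged.
python -m EasyLM.models.llama.llama_train \
    --mesh_dim='-1,16,4' \
    --train_dataset.json_dataset.batch_size=64
    # ... remaining flags as in the original command
```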