google / saxml


GPT-J model conversion failed from pytorch to paxml, throwing OOM error for TPUv3-8 #18

Open confusedgoose627 opened 9 months ago

confusedgoose627 commented 9 months ago

Hi, I am trying to serve the GPT-J 6B model on a TPUv3-8, for which I am using the saxml framework.

The error occurs when I convert the model from the PyTorch checkpoint to the Pax format that Sax expects. This is the conversion script:

https://github.com/mlcommons/inference_results_v3.1/blob/main/closed/Google/code/gptj-99/convert_gptj_ckpt.py
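For reference, I am invoking it roughly like this (the flag spellings are reconstructed from the `args.base_model_path` / `args.pax_model_path` references in the traceback below, so treat the exact names as an assumption):

```
python3 convert_gptj_ckpt.py --base-model-path EleutherAI/gpt-j-6b --pax-model-path pax_6b
```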

The admin and model servers are running correctly; I have even confirmed that they are communicating by running a sample test query.

The model pickle file is only 22.7 GB, so it should fit into the TPU cluster. Any idea?
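(As a back-of-the-envelope check, assuming the published ~6.05B parameter count for GPT-J, 22.7 GB is simply the float32 weights, which is also what each TPU core would have to hold if nothing is sharded:)

```python
# Hypothetical size check: GPT-J's ~6.05e9 parameters at 4 bytes each
# (float32) come out to roughly the checkpoint size on disk -- and to
# the 22.54G the OOM message below says was placed in HBM.
params = 6.05e9
print(params * 4 / 2**30)  # ~22.5 GiB
```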

The environment (although I have also built it from its git repo):

```
pip3 install accelerate
pip3 install torch
pip3 install transformers
pip3 install paxml==1.1.0
```

```
2024-01-03 05:23:41.411871: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
Loading the base model from EleutherAI/gpt-j-6b
transformer.wte.weight (50400, 4096)
transformer.h.0.ln_1.weight (4096,)
transformer.h.0.ln_1.bias (4096,)
transformer.h.0.attn.k_proj.weight (4096, 4096)
transformer.h.0.attn.v_proj.weight (4096, 4096)
transformer.h.0.attn.q_proj.weight (4096, 4096)
transformer.h.0.attn.out_proj.weight (4096, 4096)
transformer.h.0.mlp.fc_in.weight (16384, 4096)
transformer.h.0.mlp.fc_in.bias (16384,)
transformer.h.0.mlp.fc_out.weight (4096, 16384)
transformer.h.0.mlp.fc_out.bias (4096,)
[... transformer.h.1 through transformer.h.27 print the same shapes ...]
transformer.ln_f.weight (4096,)
transformer.ln_f.bias (4096,)
lm_head.weight (50400, 4096)
lm_head.bias (50400,)
Saving the pax model to pax_6b
Traceback (most recent call last):
  File "/home/arghyajoy627/convert_gptj_ckpt.py", line 192, in <module>
    convert(args.base_model_path, args.pax_model_path)
  File "/home/arghyajoy627/convert_gptj_ckpt.py", line 176, in convert
    jax_states_gda = pjitted_identity(jax_states)
  File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 248, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 195, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/core.py", line 2591, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/core.py", line 362, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/core.py", line 816, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 1246, in _pjit_call_impl
    compiled = _pjit_lower(
  File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2836, in compile
    self._executable = UnloadedMeshExecutable.from_hlo(
  File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 3048, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 22.54G of 15.48G hbm. Exceeded hbm capacity by 7.06G.

Total hbm usage >= 23.06G:
    reserved        530.00M
    program           4.0K
    arguments        22.54G

Output size 22.54G; shares 0B with arguments.

Program hbm requirement 4.0K:
    global            4.0K

Largest program allocations in hbm:

  1. Size: 4.0K
     Shape: u32[8,128]{1,0}
     Unpadded size: 4.0K
     XLA label: constant literal
     Allocation type: global

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/arghyajoy627/convert_gptj_ckpt.py", line 192, in <module>
    convert(args.base_model_path, args.pax_model_path)
  File "/home/arghyajoy627/convert_gptj_ckpt.py", line 176, in convert
    jax_states_gda = pjitted_identity(jax_states)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 22.54G of 15.48G hbm. Exceeded hbm capacity by 7.06G.

Total hbm usage >= 23.06G:
    reserved        530.00M
    program           4.0K
    arguments        22.54G

Output size 22.54G; shares 0B with arguments.

Program hbm requirement 4.0K:
    global            4.0K

Largest program allocations in hbm:

  1. Size: 4.0K
     Shape: u32[8,128]{1,0}
     Unpadded size: 4.0K
     XLA label: constant literal
     Allocation type: global
```

@zhihaoshan-google

NoahBPeterson commented 7 months ago

The problem you're having is that the entire model is being replicated onto each TPU core, and no single core has enough HBM to hold it.
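The numbers in the OOM message bear this out; here is a quick sanity check (the per-core figure is taken from the log, and the 16 GiB raw capacity is the published TPUv3 spec):

```python
# TPUv3-8 HBM arithmetic, using the figures from the OOM message above.
usable_hbm_per_core = 15.48   # GiB usable per core (16 GiB raw, TPUv3 spec)
weights = 22.54               # GiB of unsharded float32 parameters
num_cores = 8

print(weights > usable_hbm_per_core)              # True: too big for one core
print(weights < num_cores * usable_hbm_per_core)  # True: fits once sharded
```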

The conversion script you linked was written to run the conversion on a single GPU (https://github.com/mlcommons/inference_results_v3.1/blob/951b4a7686692d1a0d9b9067a36a7fc26d72ada5/closed/Google/code/gptj-99/convert_gptj_ckpt.py#L154C1-L154C62), not a TPU cluster, so it will not shard the weights without modification.

The offending line:

```python
device_mesh = py_utils.create_device_mesh([1, 1, num_gpus])
```

This creates a 1x1xN device mesh, where N is the number of GPUs, and `num_gpus` is hardcoded to 1, so every weight ends up on a single device.
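A hypothetical first step is to stop hardcoding a single device. This alone is not sufficient, since the partition specs passed to pjit must also stop replicating every weight; see the sketch after the next paragraph:

```python
import jax

# Build the mesh over every visible TPU core (8 on a TPUv3-8) instead of
# the hardcoded single device. The weights stay replicated until the
# sharding specs given to pjit are changed as well.
device_mesh = py_utils.create_device_mesh([1, 1, jax.device_count()])
```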

Try modifying it using the mesh sharding example from the JAX documentation, as sketched below: https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh
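A minimal sketch of what that could look like, assuming `jax_states` is the dict of converted arrays the script builds (the shard-the-largest-axis heuristic here is mine, not the script's logic):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Mesh over all 8 TPUv3-8 cores, mirroring the script's
# (replica, data, mdl) axis layout.
devices = np.array(jax.devices()).reshape(1, 1, -1)
mesh = Mesh(devices, axis_names=('replica', 'data', 'mdl'))

def shard_largest_axis(arr):
    """Shard the largest evenly divisible axis over 'mdl'; replicate otherwise."""
    spec = [None] * arr.ndim
    for i, dim in sorted(enumerate(arr.shape), key=lambda t: -t[1]):
        if dim % mesh.shape['mdl'] == 0:
            spec[i] = 'mdl'
            break
    return NamedSharding(mesh, P(*spec))

# Place each weight with its sharding so no single core holds all ~22.5 GiB.
jax_states_sharded = jax.tree_util.tree_map(
    lambda a: jax.device_put(a, shard_largest_axis(a)), jax_states)
```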