kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

RESOURCE_EXHAUSTED: Failed to allocate request for 256.00MiB (268435456B) on device ordinal 0 #209

Closed: Hendler closed this issue 2 years ago

Hendler commented 2 years ago

Posting here because a 256 MiB allocation seems particularly small to fail on a TPU VM.

Command

python ./to_hf_weights.py --input-ckpt gs://[bucket]/finetuned_one_slim/step_72 --config configs/[config].json --output-path gs://[bucket]/finetuned_one_hf --debug

Output

venv/lib/python3.8/site-packages/jax/experimental/maps.py:527: UserWarning: xmap is an experimental feature and probably has bugs!
  warn("xmap is an experimental feature and probably has bugs!")
venv/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:429: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
venv/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:416: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
key shape (1, 2)
in shape (1, 2048)
dp 1
mp 1

Stacktrace

2022-03-07 18:41:42.980600: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2124] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to allocate request for 256.00MiB (268435456B) on device ordinal 0
Traceback (most recent call last):
  File "./to_hf_weights.py", line 488, in <module>
    save_sharded_to_hf_format(input_ckpt, params, output_path, np_dtype, torch_dtype)
  File "./to_hf_weights.py", line 464, in save_sharded_to_hf_format
    network = CausalTransformer(params_local)
  File "/home/jonathan.hendler/finishing-school/mesh_transformer/transformer_shard.py", line 277, in __init__
    self.state = self.init_xmap(jnp.array(key.take(mp_per_host)), x)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 666, in fun_mapped
    out_flat = xmap_p.bind(fun_flat, *args_flat, **params)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 871, in bind
    return core.map_bind(self, fun, *args, in_axes=in_axes, **params)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/core.py", line 1801, in map_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 874, in process
    return trace.process_xmap(self, fun, tracers, params)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/core.py", line 594, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 703, in xmap_impl
    return xmap_callable(*args)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1524, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
RuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 256.00MiB (268435456B) on device ordinal 0
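The error message reports only the single allocation that failed, not how much memory was already in use on the device. The output above prints dp 1 and mp 1, which suggests the whole model state is being initialized on one device (ordinal 0), so the 256 MiB request may just be the last buffer that no longer fit. A rough sketch of that arithmetic follows. The parameter count and the 16 GiB per-device figure are hypothetical assumptions, not values from this issue (the config used here is not shown).

```python
# Back-of-the-envelope check: parameter memory versus one device's memory.
# ASSUMPTIONS: the parameter count and HBM size used below are illustrative
# only; the actual model config in this issue is elided.

MIB = 1024 ** 2
GIB = 1024 ** 3

# Size of the allocation that failed, taken from the error message.
FAILED_REQUEST_BYTES = 256 * MIB  # 268435456 B


def param_bytes(n_params: int, bytes_per_param: int) -> int:
    """Bytes needed to hold n_params values of the given width."""
    return n_params * bytes_per_param


def fits_on_one_device(n_params: int, bytes_per_param: int, hbm_gib: float) -> bool:
    """True if the parameters alone fit in a single device's memory.

    Activations, optimizer state and XLA scratch space are ignored,
    so a True result is optimistic.
    """
    return param_bytes(n_params, bytes_per_param) <= hbm_gib * GIB


if __name__ == "__main__":
    n = 6_000_000_000  # hypothetical 6B-parameter model
    for dtype, width in [("float32", 4), ("bfloat16", 2)]:
        gib = param_bytes(n, width) / GIB
        print(f"{dtype}: {gib:.1f} GiB, fits in 16 GiB: {fits_on_one_device(n, width, 16)}")
```

With these assumed numbers, float32 parameters need about 22.4 GiB and would not fit on a single 16 GiB device even before any other buffers are allocated.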

Configuration info:

TPU_VERSION = "v2-alpha" (see https://github.com/kingoflolz/mesh-transformer-jax/issues/202#issuecomment-1050887576)

Python version: Python 3.8.10

Pip freeze

absl-py==1.0.0
aiohttp==3.8.1
aiohttp-cors==0.7.0
aioredis==2.0.1
aiosignal==1.2.0
anyio==3.5.0
asgiref==3.5.0
astunparse==1.6.3
async-timeout==4.0.2
attrs==21.4.0
bcrypt==3.2.0
best-download==0.0.9
black==22.1.0
blessings==1.7
cachetools==4.2.4
certifi==2021.10.8
cffi==1.15.0
chardet==4.0.0
charset-normalizer==2.0.12
chex==0.1.1
clang==5.0
click==8.0.4
cloudpickle==1.3.0
colorama==0.4.4
colorful==0.5.4
cryptography==36.0.1
Cython==0.29.28
DataProperty==0.54.2
datasets==1.15.1
Deprecated==1.2.13
dill==0.3.4
dm-haiku==0.0.5
dm-tree==0.1.6
docker-pycreds==0.4.0
dyNET38==2.1
einops==0.3.2
fabric==2.6.0
fastapi==0.75.0
filelock==3.6.0
Flask==1.1.4
flatbuffers==1.12
frozenlist==1.3.0
fsspec==2022.2.0
ftfy==6.1.1
func-timeout==4.3.5
gast==0.4.0
gitdb==4.0.9
GitPython==3.1.27
google-api-core==2.6.0
google-auth==1.35.0
google-auth-oauthlib==0.4.6
google-cloud-core==1.7.2
google-cloud-storage==1.36.2
google-crc32c==1.3.0
google-pasta==0.2.0
google-resumable-media==1.3.3
googleapis-common-protos==1.55.0
gpustat==0.6.0
grpcio==1.44.0
h11==0.13.0
h5py==3.1.0
huggingface-hub==0.4.0
idna==2.10
importlib-metadata==4.11.2
importlib-resources==5.4.0
iniconfig==1.1.1
invoke==1.6.0
itsdangerous==1.1.0
jax==0.2.28
jaxlib==0.3.0
jieba==0.42.1
Jinja2==2.11.3
jmp==0.0.2
joblib==1.1.0
jsonlines==2.0.0
jsonschema==4.4.0
keras==2.8.0
Keras-Preprocessing==1.1.2
libclang==13.0.0
libtpu-nightly==0.1.dev20220128
lm-dataformat==0.0.20
lm-eval==0.2.0
Markdown==3.3.6
MarkupSafe==2.1.0
mbstrdecoder==1.1.0
mock==4.0.3
msgfy==0.2.0
msgpack==1.0.3
multidict==6.0.2
multiprocess==0.70.12.2
mypy-extensions==0.4.3
nagisa==0.2.7
nltk==3.7
numexpr==2.7.2
numpy==1.22.2
nvidia-ml-py3==7.352.0
oauthlib==3.2.0
openai==0.6.4
opencensus==0.8.0
opencensus-context==0.1.2
opt-einsum==3.3.0
optax==0.0.9
packaging==21.3
pandas==1.4.1
paramiko==2.9.2
pathlib2==2.3.7.post1
pathspec==0.9.0
pathtools==0.1.2
pathvalidate==2.5.0
pathy==0.6.1
platformdirs==2.5.1
pluggy==0.13.1
portalocker==2.4.0
prometheus-client==0.13.1
promise==2.3
protobuf==3.19.4
psutil==5.9.0
py==1.11.0
py-spy==0.3.11
pyarrow==7.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pybind11==2.6.2
pycountry==20.7.3
pycparser==2.21
pydantic==1.9.0
PyNaCl==1.5.0
pyparsing==3.0.7
pyrsistent==0.18.1
pytablewriter==0.58.0
pytest==6.2.3
python-dateutil==2.8.2
pytz==2021.3
PyYAML==6.0
ray==1.4.1
redis==4.1.4
regex==2022.3.2
rehash==1.0.0
requests==2.25.1
requests-oauthlib==1.3.1
rouge-score==0.0.4
rsa==4.8
sacrebleu==1.5.0
sacremoses==0.0.47
scikit-learn==1.0.2
scipy==1.8.0
sentry-sdk==1.5.6
setproctitle==1.2.2
shortuuid==1.0.8
six==1.16.0
smart-open==5.2.1
smmap==5.0.0
sniffio==1.2.0
sqlitedict==1.6.0
starlette==0.17.1
tabledata==1.3.0
tabulate==0.8.9
tcolorpy==0.1.2
tensorboard==2.8.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.8.0
tensorflow-cpu==2.6.3
tensorflow-estimator==2.6.0
tensorflow-io-gcs-filesystem==0.24.0
termcolor==1.1.0
tf-estimator-nightly==2.8.0.dev2021122109
threadpoolctl==3.1.0
tokenizers==0.11.6
toml==0.10.2
tomli==2.0.1
toolz==0.11.2
torch==1.10.2
tqdm==4.63.0
tqdm-multiprocess==0.0.11
transformers==4.17.0
typepy==1.3.0
typer==0.4.0
typing-extensions==4.1.1
ujson==5.1.0
urllib3==1.26.8
uvicorn==0.17.5
wandb==0.12.11
wcwidth==0.2.5
Werkzeug==1.0.1
wrapt==1.12.1
xxhash==3.0.0
yarl==1.7.2
yaspin==2.1.0
zipp==3.7.0
zstandard==0.15.2

Hendler commented 2 years ago

I resolved this by running the final step, the conversion to Hugging Face format (to_hf_weights.py), on a standard CPU machine with lots of memory.
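If a separate CPU machine is not at hand, a possible alternative (not tested in this thread) is to keep JAX off the TPU on the same VM so the model state lives in host RAM instead of device memory. This only helps if the host has enough RAM. The helper names below are hypothetical; JAX_PLATFORM_NAME and JAX_PLATFORMS are JAX environment variables, and which one takes effect depends on the installed JAX version.

```python
# Hypothetical helpers: re-run the conversion with JAX restricted to the CPU.
import os
import subprocess
import sys
from typing import List, Mapping, Optional


def cpu_only_env(base: Optional[Mapping[str, str]] = None) -> dict:
    """Return a copy of the environment that tells JAX to use only the CPU."""
    env = dict(os.environ if base is None else base)
    env["JAX_PLATFORM_NAME"] = "cpu"  # read by older JAX releases such as 0.2.x
    env["JAX_PLATFORMS"] = "cpu"      # read by newer JAX releases
    return env


def run_conversion_on_cpu(args: List[str]) -> subprocess.CompletedProcess:
    """Run to_hf_weights.py in a child process.

    Using a subprocess guarantees the variables are set before JAX is
    imported, which is when JAX reads them.
    """
    cmd = [sys.executable, "./to_hf_weights.py", *args]
    return subprocess.run(cmd, env=cpu_only_env(), check=True)
```

Usage, with the same placeholders as the command above:
run_conversion_on_cpu(["--input-ckpt", "gs://[bucket]/finetuned_one_slim/step_72", "--config", "configs/[config].json", "--output-path", "gs://[bucket]/finetuned_one_hf"])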