venv/lib/python3.8/site-packages/jax/experimental/maps.py:527: UserWarning: xmap is an experimental feature and probably has bugs!
  warn("xmap is an experimental feature and probably has bugs!")
venv/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:429: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
venv/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:416: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
key shape (1, 2)
in shape (1, 2048)
dp 1
mp 1
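The two deprecation warnings above name their replacements. A minimal sketch of the updated calls (not code from `to_hf_weights.py`; it assumes the script only uses the old aliases to query process count and index):

```python
import jax

# jax.process_count() replaces the deprecated alias jax.host_count()
num_processes = jax.process_count()

# jax.process_index() replaces the deprecated alias jax.host_id()
this_process = jax.process_index()

# Devices attached to this process (e.g. TPU cores on this VM)
n_local = jax.local_device_count()

print(f"process {this_process} of {num_processes}, {n_local} local devices")
```

These warnings are unrelated to the allocation failure below; switching names only silences them.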
Stacktrace
2022-03-07 18:41:42.980600: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2124] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to allocate request for 256.00MiB (268435456B) on device ordinal 0
Traceback (most recent call last):
  File "./to_hf_weights.py", line 488, in <module>
    save_sharded_to_hf_format(input_ckpt, params, output_path, np_dtype, torch_dtype)
  File "./to_hf_weights.py", line 464, in save_sharded_to_hf_format
    network = CausalTransformer(params_local)
  File "/home/jonathan.hendler/finishing-school/mesh_transformer/transformer_shard.py", line 277, in __init__
    self.state = self.init_xmap(jnp.array(key.take(mp_per_host)), x)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 666, in fun_mapped
    out_flat = xmap_p.bind(fun_flat, *args_flat, **params)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 871, in bind
    return core.map_bind(self, fun, *args, in_axes=in_axes, **params)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/core.py", line 1801, in map_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 874, in process
    return trace.process_xmap(self, fun, tracers, params)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/core.py", line 594, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 703, in xmap_impl
    return xmap_callable(*args)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1524, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
RuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 256.00MiB (268435456B) on device ordinal 0
I'm posting here because 256 MiB seems like a very small allocation to fail on a TPU VM.
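One point worth checking: the 256 MiB in the error is the size of the single request that could not be satisfied, not the total memory in use, so earlier allocations may already have filled the core. A rough back-of-the-envelope sketch of per-core parameter memory follows. Only the failed request size and `mp 1` come from the log above; the parameter count (about 6B, hypothetical), float32 storage, and 8 GiB of HBM per TPU v2 core are assumptions for illustration.

```python
GIB = 1024 ** 3
MIB = 1024 ** 2

# From the RESOURCE_EXHAUSTED message: the one request that failed
failed_request_bytes = 268_435_456
print(failed_request_bytes / MIB)  # 256.0


def params_bytes_per_core(n_params: int, bytes_per_param: int, mp: int) -> float:
    """Parameter bytes each core holds when weights are split `mp` ways
    (activations, optimizer state, and temporaries are ignored)."""
    return n_params * bytes_per_param / mp


# Assumptions: ~6B parameters (hypothetical), float32 = 4 bytes/param,
# mp=1 as printed in the log, 8 GiB HBM per TPU v2 core.
need = params_bytes_per_core(6_000_000_000, 4, mp=1)
hbm_per_core = 8 * GIB

print(f"{need / GIB:.1f} GiB of parameters per core vs {hbm_per_core / GIB:.0f} GiB HBM")
```

Under these assumptions a single core would need far more than its HBM for weights alone, which would make even a small final request fail. The real numbers depend on the checkpoint and config being converted.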
Configuration info:
https://github.com/kingoflolz/mesh-transformer-jax/issues/202#issuecomment-1050887576
TPU_VERSION = "v2-alpha"
Python version: Python 3.8.10
Pip freeze