MikeMpapa opened 1 month ago
Yeah, looks like they moved/removed `jax.experimental.maps`. The container tracks JAX head closely; we don't.
Once I merge https://github.com/stanford-crfm/haliax/pull/102, and after a suitable interval for the package to propagate, you should be able to update to the latest dev version of haliax (probably 308 or 309).
Try `pip install haliax==1.4.dev310` and see if that fixes it.
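A quick sanity check after installing (just a sketch, assuming the dev wheel has propagated to PyPI by then):

```python
# Confirm the dev build is what's installed, and that haliax now imports
# without needing the removed jax.experimental.maps module.
import importlib.metadata
import importlib.util

print(importlib.metadata.version("haliax"))  # expect something like 1.4.dev310
print(importlib.util.find_spec("jax.experimental.maps"))  # None on newer JAX

import haliax  # noqa: E402  # should import cleanly once the fix has landed
```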
Thanks for your response.
This caused a series of errors that are hard for me to trace exactly, so I'm copying the full output for now:
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
levanter.config.main(main)()
File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
response = fn(cfg, *args, **kwargs)
File "/levanter/src/levanter/main/train_lm.py", line 215, in main
trainer.train(state, train_loader)
File "/levanter/src/levanter/trainer.py", line 403, in train
for info in self.training_steps(state, train_loader, run_hooks=run_hooks):
File "/levanter/src/levanter/trainer.py", line 386, in training_steps
info = self.train_step(state, example)
File "/levanter/src/levanter/trainer.py", line 370, in train_step
loss, new_state = self._jit_train_step_fn(state, *batch, **batch_kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 261, in __call__
return self._call(False, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/equinox/_module.py", line 1053, in __call__
return self.__func__(self.__self__, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 315, in _call
output_shape = _cached_filter_eval_shape(self._fn, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 546, in _cached_filter_eval_shape
_eval_shape_cache[static] = eqx.filter_eval_shape(fun, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/equinox/_eval_shape.py", line 38, in filter_eval_shape
dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
File "/levanter/src/levanter/trainer.py", line 498, in _train_step
loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key)
File "/levanter/src/levanter/trainer.py", line 515, in _compute_gradients_microbatched
return grad_fn(model, *batch, **batch_kwargs)
File "/levanter/src/levanter/grad_accum.py", line 92, in wrapped_fn
r_shape = eqx.filter_eval_shape(fn, *args, **kwargs)
File "/levanter/src/levanter/trainer.py", line 191, in fn
return _ensure_scalar(self._raw_loss_function(model, *batch, **batch_kwargs))
File "/levanter/src/levanter/types.py", line 75, in __call__
return model.compute_loss(*inputs, reduction=reduction, reduction_axis=reduction_axis, **kwargs)
File "/levanter/src/levanter/models/lm_model.py", line 129, in compute_loss
logits = self(example.tokens, example.attn_mask, key=key)
File "/levanter/src/levanter/models/gpt2.py", line 399, in __call__
x = self.transformer(x, attn_mask, key=k_transformer)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/levanter/src/levanter/models/gpt2.py", line 298, in __call__
x = self.blocks.fold(x, attn_mask, hax.arange(self.config.Layers), key=keys)
File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 221, in fold
return haliax.fold(do_block, self.Block)(init, self.stacked, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 202, in scanned_f
return scan_preconfig(init, *args, **kwargs)[0]
File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 134, in scanned_f
carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll)
File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 127, in wrapped_fn
carry, y = f(carry, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 197, in scan_compatible_fn
return fn(carry, *args, **kwargs), None
File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 75, in wrapper
dynamic_out, static_out = checkpointed_fun(static, dynamic)
File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 66, in _fn
_out = fun(*_args, **_kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 225, in _do_block
return block(carry, *extra_args, **extra_kwargs)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/levanter/src/levanter/models/gpt2.py", line 268, in __call__
attn_output = self.attn(self.ln_1(x), mask=mask, layer_idx=layer_idx, key=k1)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/levanter/src/levanter/models/gpt2.py", line 200, in __call__
attn_output = dot_product_attention(
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/levanter/src/levanter/models/attention.py", line 119, in dot_product_attention
attention_out = _try_te_attention(
File "/levanter/src/levanter/models/attention.py", line 242, in _try_te_attention
return _te_flash_attention(
File "/levanter/src/levanter/models/attention.py", line 313, in _te_flash_attention
from transformer_engine.jax.fused_attn import fused_attn # noqa: F401
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/__init__.py", line 10, in <module>
import transformer_engine.common
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 107, in <module>
_TE_LIB_CTYPES = _load_library()
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 78, in _load_library
return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
File "/usr/lib/python3.10/ctypes/__init__.py", line 374, in __init__
self._handle = _dlopen(self._name, mode)
OSError: /usr/local/lib/python3.10/dist-packages/transformer_engine/libtransformer_engine.so: undefined symbol: cudnnGetLastErrorString
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
levanter.config.main(main)()
File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
response = fn(cfg, *args, **kwargs)
File "/levanter/src/levanter/main/train_lm.py", line 215, in main
trainer.train(state, train_loader)
File "/levanter/src/levanter/trainer.py", line 403, in train
for info in self.training_steps(state, train_loader, run_hooks=run_hooks):
File "/levanter/src/levanter/trainer.py", line 386, in training_steps
info = self.train_step(state, example)
File "/levanter/src/levanter/trainer.py", line 370, in train_step
loss, new_state = self._jit_train_step_fn(state, *batch, **batch_kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 261, in __call__
return self._call(False, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/equinox/_module.py", line 1053, in __call__
return self.__func__(self.__self__, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 315, in _call
output_shape = _cached_filter_eval_shape(self._fn, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 546, in _cached_filter_eval_shape
_eval_shape_cache[static] = eqx.filter_eval_shape(fun, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/equinox/_eval_shape.py", line 38, in filter_eval_shape
dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
File "/levanter/src/levanter/trainer.py", line 498, in _train_step
loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key)
File "/levanter/src/levanter/trainer.py", line 515, in _compute_gradients_microbatched
return grad_fn(model, *batch, **batch_kwargs)
File "/levanter/src/levanter/grad_accum.py", line 92, in wrapped_fn
r_shape = eqx.filter_eval_shape(fn, *args, **kwargs)
File "/levanter/src/levanter/trainer.py", line 191, in fn
return _ensure_scalar(self._raw_loss_function(model, *batch, **batch_kwargs))
File "/levanter/src/levanter/types.py", line 75, in __call__
return model.compute_loss(*inputs, reduction=reduction, reduction_axis=reduction_axis, **kwargs)
File "/levanter/src/levanter/models/lm_model.py", line 129, in compute_loss
logits = self(example.tokens, example.attn_mask, key=key)
File "/levanter/src/levanter/models/gpt2.py", line 399, in __call__
x = self.transformer(x, attn_mask, key=k_transformer)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/levanter/src/levanter/models/gpt2.py", line 298, in __call__
x = self.blocks.fold(x, attn_mask, hax.arange(self.config.Layers), key=keys)
File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 221, in fold
return haliax.fold(do_block, self.Block)(init, self.stacked, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 202, in scanned_f
return scan_preconfig(init, *args, **kwargs)[0]
File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 134, in scanned_f
carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll)
File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 127, in wrapped_fn
carry, y = f(carry, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 197, in scan_compatible_fn
return fn(carry, *args, **kwargs), None
File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 75, in wrapper
dynamic_out, static_out = checkpointed_fun(static, dynamic)
File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 66, in _fn
_out = fun(*_args, **_kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 225, in _do_block
return block(carry, *extra_args, **extra_kwargs)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/levanter/src/levanter/models/gpt2.py", line 268, in __call__
attn_output = self.attn(self.ln_1(x), mask=mask, layer_idx=layer_idx, key=k1)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/levanter/src/levanter/models/gpt2.py", line 200, in __call__
attn_output = dot_product_attention(
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/levanter/src/levanter/models/attention.py", line 119, in dot_product_attention
attention_out = _try_te_attention(
File "/levanter/src/levanter/models/attention.py", line 242, in _try_te_attention
return _te_flash_attention(
File "/levanter/src/levanter/models/attention.py", line 313, in _te_flash_attention
from transformer_engine.jax.fused_attn import fused_attn # noqa: F401
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/__init__.py", line 10, in <module>
import transformer_engine.common
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 107, in <module>
_TE_LIB_CTYPES = _load_library()
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 78, in _load_library
return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
File "/usr/lib/python3.10/ctypes/__init__.py", line 374, in __init__
self._handle = _dlopen(self._name, mode)
OSError: /usr/local/lib/python3.10/dist-packages/transformer_engine/libtransformer_engine.so: undefined symbol: cudnnGetLastErrorString
2024-07-25 20:44:28,532 WARNING worker.py:1450 -- SIGTERM handler is not set because current thread is not the main thread.
/usr/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = _posixsubprocess.fork_exec(
wandb:
wandb: Run history:
wandb: preprocessing//train/chunks ▁
wandb: preprocessing//train/finished ▁
wandb: preprocessing//train/input_ids ▁
wandb: preprocessing//train/rows ▁
wandb: preprocessing//train/shards ▁
wandb: preprocessing//train/token_type_ids ▁
wandb: preprocessing//validation/chunks ▁
wandb: preprocessing//validation/finished ▁
wandb: preprocessing//validation/input_ids ▁
wandb: preprocessing//validation/rows ▁
wandb: preprocessing//validation/shards ▁
wandb: preprocessing//validation/token_type_ids ▁
wandb:
wandb: Run summary:
wandb: backend gpu
wandb: num_devices 1
wandb: num_hosts 1
wandb: parameter_count 359708672
wandb: preprocessing//train/chunks 1
wandb: preprocessing//train/finished 1
wandb: preprocessing//train/input_ids 285696
wandb: preprocessing//train/rows 279
wandb: preprocessing//train/shards 1
wandb: preprocessing//train/token_type_ids 285696
wandb: preprocessing//validation/chunks 1
wandb: preprocessing//validation/finished 1
wandb: preprocessing//validation/input_ids 28672
wandb: preprocessing//validation/rows 28
wandb: preprocessing//validation/shards 1
wandb: preprocessing//validation/token_type_ids 28672
wandb: throughput/device_kind NVIDIA A10G
wandb: throughput/flops_per_example 2514493636608.0
wandb: throughput/theoretical_flops 125000000000000.0
wandb: throughput/theoretical_flops_per_device 125000000000000.0
wandb:
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /levanter/wandb/offline-run-20240725_204346-cirlyl1x
wandb: Find logs at: ./wandb/offline-run-20240725_204346-cirlyl1x/logs
wandb: WARNING The new W&B backend becomes opt-out in version 0.18.0; try it out with `wandb.require("core")`! See https://wandb.me/wandb-core for more information.
2024-07-25 20:44:30,777 INFO worker.py:1779 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265
2024-07-25T20:44:31 - 0 - ShardCache.cache/train - shard_cache.py:1418 - ERROR :: Error while reading from shard cache.
Traceback (most recent call last):
File "/levanter/src/levanter/data/shard_cache.py", line 1406, in iter_batches_from_chunks
chunk = self._get_chunk_unmapped(i)
File "/levanter/src/levanter/data/shard_cache.py", line 1336, in _get_chunk_unmapped
chunk = ray.get(self._broker.get_chunk.remote(mapped_index), timeout=current_timeout)
File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 202, in remote
return self._remote(args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/ray/util/tracing/tracing_helper.py", line 426, in _start_span
return method(self, args, kwargs, *_args, **_kwargs)
File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 330, in _remote
return invocation(args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 311, in invocation
return actor._actor_method_call(
File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 1460, in _actor_method_call
object_refs = worker.core_worker.submit_actor_task(
File "python/ray/_raylet.pyx", line 4258, in ray._raylet.CoreWorker.submit_actor_task
File "python/ray/_raylet.pyx", line 4313, in ray._raylet.CoreWorker.submit_actor_task
Exception: Failed to submit task to actor ActorID(01757d23543161964c52264701000000) due to b"Can't find actor 01757d23543161964c52264701000000. It might be dead or it's from a different cluster"
Please ignore and let me retest - I wasn't on Levanter head, so that might be it. Will follow up.
Yeah, same output unfortunately. Is there a way I can access older JAX containers?
INFO:levanter.distributed:Not initializing jax.distributed because no distributed config was provided, and no cluster was detected.
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:levanter.trainer:Setting run id to oz48brcz
2024-07-25T21:34:46 - 0 - levanter.tracker.wandb - wandb.py:200 - INFO :: Setting wandb code_dir to /levanter
2024-07-25T21:34:46 - 0 - wandb.sdk.lib.gitlib - gitlib.py:92 - ERROR :: git root error: Cmd('git') failed due to: exit code(128)
cmdline: git rev-parse --show-toplevel
stderr: 'fatal: detected dubious ownership in repository at '/levanter'
To add an exception for this directory, call:
git config --global --add safe.directory /levanter'
2024-07-25T21:34:46 - 0 - wandb.sdk.lib.gitlib - gitlib.py:92 - ERROR :: git root error: Cmd('git') failed due to: exit code(128)
cmdline: git rev-parse --show-toplevel
stderr: 'fatal: detected dubious ownership in repository at '/levanter'
To add an exception for this directory, call:
git config --global --add safe.directory /levanter'
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 3
wandb: You chose "Don't visualize my results"
wandb: WARNING `resume` will be ignored since W&B syncing is set to `offline`. Starting a new run with run id oz48brcz.
wandb: Tracking run with wandb version 0.17.5
wandb: W&B syncing is set to `offline` in this directory.
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
2024-07-25T21:34:49 - 0 - levanter.distributed - distributed.py:215 - INFO :: No auto-discovered ray address found. Using ray.init('local').
2024-07-25T21:34:49 - 0 - levanter.distributed - distributed.py:267 - INFO :: ray.init(address='local', namespace='levanter', **{})
/usr/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = _posixsubprocess.fork_exec(
/usr/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = _posixsubprocess.fork_exec(
2024-07-25 21:34:51,416 INFO worker.py:1779 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265
2024-07-25T21:34:52 - 0 - levanter.tracker.wandb - wandb.py:200 - INFO :: Setting wandb code_dir to /levanter
2024-07-25T21:34:52 - 0 - levanter.tracker.wandb - wandb.py:200 - INFO :: Setting wandb code_dir to /levanter
train: 0%| | 0/50 [00:00<?, ?it/s]
2024-07-25T21:34:53 - 0 - levanter.data.shard_cache - shard_cache.py:1266 - INFO :: Loading cache from /cache/validation
2024-07-25T21:34:53 - 0 - levanter.data.text - text.py:692 - INFO :: Building cache for validation...
2024-07-25T21:34:53 - 0 - levanter.data.shard_cache - shard_cache.py:1266 - INFO :: Loading cache from /cache/validation
(ChunkCacheBuilder pid=505) 2024-07-25 21:34:59,609 - levanter.data.shard_cache.builder::cache/validation - INFO - Starting cache build for 1 shards
2024-07-25T21:35:06 - 0 - levanter.data.text - text.py:256 - INFO :: Cache /cache/validation is complete.
2024-07-25T21:35:06 - 0 - levanter.data.shard_cache - shard_cache.py:1266 - INFO :: Loading cache from /cache/train
2024-07-25T21:35:06 - 0 - levanter.data.text - text.py:692 - INFO :: Building cache for train...
2024-07-25T21:35:06 - 0 - levanter.data.shard_cache - shard_cache.py:1266 - INFO :: Loading cache from /cache/train
2024-07-25T21:35:06 - 0 - preprocessing..validation - metrics_monitor.py:143 - INFO :: done: Shards: 0 | Chunks: 1 | Docs: 28
2024-07-25T21:35:06 - 0 - preprocessing..validation - metrics_monitor.py:143 - INFO :: done: Shards: 1 | Chunks: 1 | Docs: 28
2024-07-25T21:35:06 - 0 - preprocessing..validation - metrics_monitor.py:150 - INFO :: Cache creation finished
(ChunkCacheBroker pid=460) 2024-07-25 21:35:06,723 - levanter.data.shard_cache - INFO - Finalizing cache /cache/validation...
(ChunkCacheBuilder pid=505) 2024-07-25 21:35:06,718 - levanter.data.shard_cache - INFO - Shard valid_txt finished
2024-07-25T21:35:10 - 0 - levanter.data.text - text.py:258 - INFO :: Cache /cache/train is incomplete. This will block until at least one chunk per process is complete.
(ChunkCacheBuilder pid=699) 2024-07-25 21:35:13,905 - levanter.data.shard_cache.builder::cache/train - INFO - Starting cache build for 1 shards
2024-07-25T21:35:14 - 0 - __main__ - train_lm.py:129 - INFO :: No training checkpoint found. Initializing model from HF checkpoint 'stanford-crfm/music-medium-800k'
config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.96k/1.96k [00:00<00:00, 8.35MB/s]
config.json: 0%| | 0.00/1.96k [00:00<?, ?B/s]
2024-07-25T21:35:21 - 0 - preprocessing..train - metrics_monitor.py:143 - INFO :: done: Shards: 0 | Chunks: 1 | Docs: 279
2024-07-25T21:35:21 - 0 - preprocessing..train - metrics_monitor.py:143 - INFO :: done: Shards: 1 | Chunks: 1 | Docs: 279
2024-07-25T21:35:21 - 0 - preprocessing..train - metrics_monitor.py:143 - INFO :: done: Shards: 1 | Chunks: 1 | Docs: 279
2024-07-25T21:35:21 - 0 - preprocessing..train - metrics_monitor.py:150 - INFO :: Cache creation finished
(ChunkCacheBroker pid=664) 2024-07-25 21:35:21,188 - levanter.data.shard_cache - INFO - Finalizing cache /cache/train...
(ChunkCacheBuilder pid=699) 2024-07-25 21:35:21,183 - levanter.data.shard_cache - INFO - Shard train_txt finished
pytorch_model.bin: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.44G/1.44G [00:38<00:00, 37.6MB/s]
Loading weights: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 293/293 [00:01<00:00, 189.18it/s]
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
levanter.config.main(main)()
File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
response = fn(cfg, *args, **kwargs)
File "/levanter/src/levanter/main/train_lm.py", line 215, in main
trainer.train(state, train_loader)
File "/levanter/src/levanter/trainer.py", line 403, in train
for info in self.training_steps(state, train_loader, run_hooks=run_hooks):
File "/levanter/src/levanter/trainer.py", line 386, in training_steps
info = self.train_step(state, example)
File "/levanter/src/levanter/trainer.py", line 370, in train_step
loss, new_state = self._jit_train_step_fn(state, *batch, **batch_kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 261, in __call__
return self._call(False, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/equinox/_module.py", line 1053, in __call__
return self.__func__(self.__self__, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 315, in _call
output_shape = _cached_filter_eval_shape(self._fn, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 546, in _cached_filter_eval_shape
_eval_shape_cache[static] = eqx.filter_eval_shape(fun, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/equinox/_eval_shape.py", line 38, in filter_eval_shape
dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
File "/levanter/src/levanter/trainer.py", line 498, in _train_step
loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key)
File "/levanter/src/levanter/trainer.py", line 515, in _compute_gradients_microbatched
return grad_fn(model, *batch, **batch_kwargs)
File "/levanter/src/levanter/grad_accum.py", line 92, in wrapped_fn
r_shape = eqx.filter_eval_shape(fn, *args, **kwargs)
File "/levanter/src/levanter/trainer.py", line 191, in fn
return _ensure_scalar(self._raw_loss_function(model, *batch, **batch_kwargs))
File "/levanter/src/levanter/types.py", line 75, in __call__
return model.compute_loss(*inputs, reduction=reduction, reduction_axis=reduction_axis, **kwargs)
File "/levanter/src/levanter/models/lm_model.py", line 129, in compute_loss
logits = self(example.tokens, example.attn_mask, key=key)
File "/levanter/src/levanter/models/gpt2.py", line 399, in __call__
x = self.transformer(x, attn_mask, key=k_transformer)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/levanter/src/levanter/models/gpt2.py", line 298, in __call__
x = self.blocks.fold(x, attn_mask, hax.arange(self.config.Layers), key=keys)
File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 221, in fold
return haliax.fold(do_block, self.Block)(init, self.stacked, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 202, in scanned_f
return scan_preconfig(init, *args, **kwargs)[0]
File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 134, in scanned_f
carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll)
File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 127, in wrapped_fn
carry, y = f(carry, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 197, in scan_compatible_fn
return fn(carry, *args, **kwargs), None
File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 75, in wrapper
dynamic_out, static_out = checkpointed_fun(static, dynamic)
File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 66, in _fn
_out = fun(*_args, **_kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 225, in _do_block
return block(carry, *extra_args, **extra_kwargs)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/levanter/src/levanter/models/gpt2.py", line 268, in __call__
attn_output = self.attn(self.ln_1(x), mask=mask, layer_idx=layer_idx, key=k1)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/levanter/src/levanter/models/gpt2.py", line 200, in __call__
attn_output = dot_product_attention(
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/levanter/src/levanter/models/attention.py", line 119, in dot_product_attention
attention_out = _try_te_attention(
File "/levanter/src/levanter/models/attention.py", line 242, in _try_te_attention
return _te_flash_attention(
File "/levanter/src/levanter/models/attention.py", line 313, in _te_flash_attention
from transformer_engine.jax.fused_attn import fused_attn # noqa: F401
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/__init__.py", line 10, in <module>
import transformer_engine.common
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 107, in <module>
_TE_LIB_CTYPES = _load_library()
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 78, in _load_library
return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
File "/usr/lib/python3.10/ctypes/__init__.py", line 374, in __init__
self._handle = _dlopen(self._name, mode)
OSError: /usr/local/lib/python3.10/dist-packages/transformer_engine/libtransformer_engine.so: undefined symbol: cudnnGetLastErrorString
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
levanter.config.main(main)()
File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
response = fn(cfg, *args, **kwargs)
File "/levanter/src/levanter/main/train_lm.py", line 215, in main
trainer.train(state, train_loader)
File "/levanter/src/levanter/trainer.py", line 403, in train
for info in self.training_steps(state, train_loader, run_hooks=run_hooks):
File "/levanter/src/levanter/trainer.py", line 386, in training_steps
info = self.train_step(state, example)
File "/levanter/src/levanter/trainer.py", line 370, in train_step
loss, new_state = self._jit_train_step_fn(state, *batch, **batch_kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 261, in __call__
return self._call(False, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/equinox/_module.py", line 1053, in __call__
return self.__func__(self.__self__, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 315, in _call
output_shape = _cached_filter_eval_shape(self._fn, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 546, in _cached_filter_eval_shape
_eval_shape_cache[static] = eqx.filter_eval_shape(fun, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/equinox/_eval_shape.py", line 38, in filter_eval_shape
dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
File "/levanter/src/levanter/trainer.py", line 498, in _train_step
loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key)
File "/levanter/src/levanter/trainer.py", line 515, in _compute_gradients_microbatched
return grad_fn(model, *batch, **batch_kwargs)
File "/levanter/src/levanter/grad_accum.py", line 92, in wrapped_fn
r_shape = eqx.filter_eval_shape(fn, *args, **kwargs)
File "/levanter/src/levanter/trainer.py", line 191, in fn
return _ensure_scalar(self._raw_loss_function(model, *batch, **batch_kwargs))
File "/levanter/src/levanter/types.py", line 75, in __call__
return model.compute_loss(*inputs, reduction=reduction, reduction_axis=reduction_axis, **kwargs)
File "/levanter/src/levanter/models/lm_model.py", line 129, in compute_loss
logits = self(example.tokens, example.attn_mask, key=key)
File "/levanter/src/levanter/models/gpt2.py", line 399, in __call__
x = self.transformer(x, attn_mask, key=k_transformer)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/levanter/src/levanter/models/gpt2.py", line 298, in __call__
x = self.blocks.fold(x, attn_mask, hax.arange(self.config.Layers), key=keys)
File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 221, in fold
return haliax.fold(do_block, self.Block)(init, self.stacked, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 202, in scanned_f
return scan_preconfig(init, *args, **kwargs)[0]
File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 134, in scanned_f
carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll)
File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 127, in wrapped_fn
carry, y = f(carry, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 197, in scan_compatible_fn
return fn(carry, *args, **kwargs), None
File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 75, in wrapper
dynamic_out, static_out = checkpointed_fun(static, dynamic)
File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 66, in _fn
_out = fun(*_args, **_kwargs)
File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 225, in _do_block
return block(carry, *extra_args, **extra_kwargs)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/levanter/src/levanter/models/gpt2.py", line 268, in __call__
attn_output = self.attn(self.ln_1(x), mask=mask, layer_idx=layer_idx, key=k1)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/levanter/src/levanter/models/gpt2.py", line 200, in __call__
attn_output = dot_product_attention(
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/levanter/src/levanter/models/attention.py", line 119, in dot_product_attention
attention_out = _try_te_attention(
File "/levanter/src/levanter/models/attention.py", line 242, in _try_te_attention
return _te_flash_attention(
File "/levanter/src/levanter/models/attention.py", line 313, in _te_flash_attention
from transformer_engine.jax.fused_attn import fused_attn # noqa: F401
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/__init__.py", line 10, in <module>
import transformer_engine.common
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 107, in <module>
_TE_LIB_CTYPES = _load_library()
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 78, in _load_library
return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
File "/usr/lib/python3.10/ctypes/__init__.py", line 374, in __init__
self._handle = _dlopen(self._name, mode)
OSError: /usr/local/lib/python3.10/dist-packages/transformer_engine/libtransformer_engine.so: undefined symbol: cudnnGetLastErrorString
2024-07-25 21:36:00,899 WARNING worker.py:1450 -- SIGTERM handler is not set because current thread is not the main thread.
/usr/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = _posixsubprocess.fork_exec(
wandb:
wandb: Run history:
wandb: preprocessing//train/chunks ▁
wandb: preprocessing//train/finished ▁
wandb: preprocessing//train/input_ids ▁
wandb: preprocessing//train/rows ▁
wandb: preprocessing//train/shards ▁
wandb: preprocessing//train/token_type_ids ▁
wandb: preprocessing//validation/chunks ▁
wandb: preprocessing//validation/finished ▁
wandb: preprocessing//validation/input_ids ▁
wandb: preprocessing//validation/rows ▁
wandb: preprocessing//validation/shards ▁
wandb: preprocessing//validation/token_type_ids ▁
wandb:
wandb: Run summary:
wandb: backend gpu
wandb: num_devices 1
wandb: num_hosts 1
wandb: parameter_count 359708672
wandb: preprocessing//train/chunks 1
wandb: preprocessing//train/finished 1
wandb: preprocessing//train/input_ids 285696
wandb: preprocessing//train/rows 279
wandb: preprocessing//train/shards 1
wandb: preprocessing//train/token_type_ids 285696
wandb: preprocessing//validation/chunks 1
wandb: preprocessing//validation/finished 1
wandb: preprocessing//validation/input_ids 28672
wandb: preprocessing//validation/rows 28
wandb: preprocessing//validation/shards 1
wandb: preprocessing//validation/token_type_ids 28672
wandb: throughput/device_kind NVIDIA A10G
wandb: throughput/flops_per_example 2514493636608.0
wandb: throughput/theoretical_flops 125000000000000.0
wandb: throughput/theoretical_flops_per_device 125000000000000.0
wandb:
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /levanter/wandb/offline-run-20240725_213448-oz48brcz
wandb: Find logs at: ./wandb/offline-run-20240725_213448-oz48brcz/logs
wandb: WARNING The new W&B backend becomes opt-out in version 0.18.0; try it out with `wandb.require("core")`! See https://wandb.me/wandb-core for more information.
2024-07-25 21:36:03,171 INFO worker.py:1779 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265
2024-07-25T21:36:04 - 0 - ShardCache.cache/train - shard_cache.py:1418 - ERROR :: Error while reading from shard cache.
Traceback (most recent call last):
File "/levanter/src/levanter/data/shard_cache.py", line 1406, in iter_batches_from_chunks
chunk = self._get_chunk_unmapped(i)
File "/levanter/src/levanter/data/shard_cache.py", line 1336, in _get_chunk_unmapped
chunk = ray.get(self._broker.get_chunk.remote(mapped_index), timeout=current_timeout)
File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 202, in remote
return self._remote(args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/ray/util/tracing/tracing_helper.py", line 426, in _start_span
return method(self, args, kwargs, *_args, **_kwargs)
File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 330, in _remote
return invocation(args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 311, in invocation
return actor._actor_method_call(
File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 1460, in _actor_method_call
object_refs = worker.core_worker.submit_actor_task(
File "python/ray/_raylet.pyx", line 4258, in ray._raylet.CoreWorker.submit_actor_task
File "python/ray/_raylet.pyx", line 4313, in ray._raylet.CoreWorker.submit_actor_task
Exception: Failed to submit task to actor ActorID(0e2b26ec3dd9161148784ee601000000) due to b"Can't find actor 0e2b26ec3dd9161148784ee601000000. It might be dead or it's from a different cluster"
Exception in thread ray_print_logs:
Traceback (most recent call last):
File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
self.run()
File "/usr/lib/python3.10/threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 893, in print_logs
subscriber.subscribe()
File "python/ray/_raylet.pyx", line 3111, in ray._raylet._GcsSubscriber.subscribe
File "python/ray/_raylet.pyx", line 586, in ray._raylet.check_status
ray.exceptions.RpcError: recvmsg:Connection reset by peer
This seems like a problem with the CUDA stuff. @DwarKapex Any thoughts?
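For what it's worth, `cudnnGetLastErrorString` looks like a cuDNN 9-era API, so my guess is the TransformerEngine build in the image was compiled against a newer cuDNN than the one on the library path. A quick check from inside the container (a sketch; it assumes the unversioned `libcudnn.so` symlink exists, otherwise substitute the versioned name):

```python
# Load the same cuDNN the dynamic loader would hand to TransformerEngine,
# then check whether it exports the symbol the OSError complains about.
import ctypes

cudnn = ctypes.CDLL("libcudnn.so", mode=ctypes.RTLD_GLOBAL)
print(hasattr(cudnn, "cudnnGetLastErrorString"))  # False => cuDNN is too old
```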
(You can probably use https://github.com/orgs/nvidia/packages/container/jax/248105500?tag=levanter-2024-07-24 as jax:levanter-2024-07-24)
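If GHCR's usual naming holds for that package, it should be pullable as something like `docker pull ghcr.io/nvidia/jax:levanter-2024-07-24`.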
That did the trick! Thanks so much!
If it is of any help to you, here is a slightly more detailed description of what I am trying to do:
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
levanter.config.main(main)()
File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
response = fn(cfg, *args, **kwargs)
File "/levanter/src/levanter/main/train_lm.py", line 82, in main
levanter.initialize(config)
File "/levanter/src/levanter/trainer.py", line 796, in initialize
trainer_config.initialize()
File "/levanter/src/levanter/trainer.py", line 627, in initialize
_initialize_global_tracker(self.tracker, id)
File "/levanter/src/levanter/trainer.py", line 522, in _initialize_global_tracker
tracker = config.init(run_id)
File "/levanter/src/levanter/tracker/wandb.py", line 131, in init
git_settings = self._git_settings()
File "/levanter/src/levanter/tracker/wandb.py", line 204, in _git_settings
sha = self._get_git_sha(code_dir)
File "/levanter/src/levanter/tracker/wandb.py", line 216, in _get_git_sha
git_sha = repo.head.commit.hexsha
File "/usr/local/lib/python3.10/dist-packages/git/refs/symbolic.py", line 297, in _get_commit
obj = self._get_object()
File "/usr/local/lib/python3.10/dist-packages/git/refs/symbolic.py", line 288, in _get_object
return Object.new_from_sha(self.repo, hex_to_bin(self.dereference_recursive(self.repo, self.path)))
File "/usr/local/lib/python3.10/dist-packages/git/objects/base.py", line 149, in new_from_sha
oinfo = repo.odb.info(sha1)
File "/usr/local/lib/python3.10/dist-packages/git/db.py", line 41, in info
hexsha, typename, size = self._git.get_object_header(bin_to_hex(binsha))
File "/usr/local/lib/python3.10/dist-packages/git/cmd.py", line 1678, in get_object_header
return self.__get_object_header(cmd, ref)
File "/usr/local/lib/python3.10/dist-packages/git/cmd.py", line 1661, in __get_object_header
cmd.stdin.flush()
BrokenPipeError: [Errno 32] Broken pipe
Sharing this in case it helps you guys with debugging. Feel free to follow up if you have any questions.
Weird. That's easy to work around, but I don't know why it's happening.
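Two workarounds come to mind: running `git config --global --add safe.directory /levanter` inside the container (the "dubious ownership" message in the earlier log suggests exactly that), or having the tracker treat any failure while probing git as "no git info". A sketch of the latter, with illustrative names rather than Levanter's actual code:

```python
# Swallow git failures when collecting run metadata, so a broken git setup
# (BrokenPipeError, dubious-ownership refusal, no repo at all) can't take
# down tracker initialization. get_git_sha is a hypothetical helper.
def get_git_sha(code_dir):
    try:
        import git  # GitPython, as in the traceback above

        return git.Repo(code_dir).head.commit.hexsha
    except Exception:
        return None  # proceed without a commit SHA
```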
Hi - after yesterday's code update I am getting the following error. Any advice? I am using the JAX-Levanter Docker image.
EDIT: Actually, I now see the main error with the previous repo HEAD as well, which is weird because yesterday the container was working fine. Do you think something has changed on the NVIDIA image?