stanford-crfm / levanter

Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax
https://levanter.readthedocs.io/en/latest/
Apache License 2.0

ModuleNotFoundError: No module named 'jax.experimental.maps' after main-branch update (#662) #673

Open MikeMpapa opened 1 month ago

MikeMpapa commented 1 month ago

Hi - after yesterday's code update I am getting the following error. Any advice? I am using the JAX/Levanter Docker image.

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
    levanter.config.main(main)()
  File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/levanter/src/levanter/main/train_lm.py", line 119, in main
    Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping)
  File "/opt/haliax/src/haliax/partitioning.py", line 597, in round_axis_for_partitioning
    size = physical_axis_size(axis, mapping)
  File "/opt/haliax/src/haliax/partitioning.py", line 566, in physical_axis_size
    mesh = _get_mesh()
  File "/opt/haliax/src/haliax/partitioning.py", line 606, in _get_mesh
    from jax.experimental.maps import thread_resources
ModuleNotFoundError: No module named 'jax.experimental.maps'

EDIT: Actually, I now see the same error with the previous repo HEAD, which is weird because the container was working fine yesterday. Do you think something has changed in the NVIDIA image?

dlwh commented 1 month ago

Yeah, looks like they moved/removed jax.experimental.maps. The container tracks close to JAX head; we don't.
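In the meantime, a version-tolerant import along these lines should sidestep the removed module (just a sketch of the pattern, not necessarily the exact haliax fix; note that `jax._src.mesh` is a private module, so this is inherently fragile):

```python
# Sketch: locate thread_resources across JAX versions.
# jax.experimental.maps was removed in newer JAX releases; the object it
# re-exported lives in jax._src.mesh (private API, so it may move again).
try:
    from jax.experimental.maps import thread_resources  # older JAX
except ImportError:
    from jax._src.mesh import thread_resources  # newer JAX

def current_mesh():
    """Return the mesh installed by the active `with mesh:` block, if any."""
    return thread_resources.env.physical_mesh
```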

dlwh commented 1 month ago

Once I merge https://github.com/stanford-crfm/haliax/pull/102 and allow a suitable interval for the package to propagate, you should be able to update to the latest dev version of haliax (probably 308 or 309).

dlwh commented 1 month ago

Try `pip install haliax==1.4.dev310` and see if that fixes it.
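If you want a quick sanity check that the pin took effect (a minimal sketch, assuming the image's default Python):

```python
# Confirm the pinned haliax is installed and the failing import now resolves.
from importlib.metadata import version

print("haliax", version("haliax"))
print("jax", version("jax"))

import haliax.partitioning  # raised ModuleNotFoundError with old haliax + new JAX
print("haliax.partitioning imported OK")
```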

MikeMpapa commented 1 month ago

Thanks for your response.

This caused a series of errors that are hard for me to trace exactly, so I'm copying the full output for now:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
    levanter.config.main(main)()
  File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/levanter/src/levanter/main/train_lm.py", line 215, in main
    trainer.train(state, train_loader)
  File "/levanter/src/levanter/trainer.py", line 403, in train
    for info in self.training_steps(state, train_loader, run_hooks=run_hooks):
  File "/levanter/src/levanter/trainer.py", line 386, in training_steps
    info = self.train_step(state, example)
  File "/levanter/src/levanter/trainer.py", line 370, in train_step
    loss, new_state = self._jit_train_step_fn(state, *batch, **batch_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 261, in __call__
    return self._call(False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/equinox/_module.py", line 1053, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 315, in _call
    output_shape = _cached_filter_eval_shape(self._fn, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 546, in _cached_filter_eval_shape
    _eval_shape_cache[static] = eqx.filter_eval_shape(fun, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/equinox/_eval_shape.py", line 38, in filter_eval_shape
    dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
  File "/levanter/src/levanter/trainer.py", line 498, in _train_step
    loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key)
  File "/levanter/src/levanter/trainer.py", line 515, in _compute_gradients_microbatched
    return grad_fn(model, *batch, **batch_kwargs)
  File "/levanter/src/levanter/grad_accum.py", line 92, in wrapped_fn
    r_shape = eqx.filter_eval_shape(fn, *args, **kwargs)
  File "/levanter/src/levanter/trainer.py", line 191, in fn
    return _ensure_scalar(self._raw_loss_function(model, *batch, **batch_kwargs))
  File "/levanter/src/levanter/types.py", line 75, in __call__
    return model.compute_loss(*inputs, reduction=reduction, reduction_axis=reduction_axis, **kwargs)
  File "/levanter/src/levanter/models/lm_model.py", line 129, in compute_loss
    logits = self(example.tokens, example.attn_mask, key=key)
  File "/levanter/src/levanter/models/gpt2.py", line 399, in __call__
    x = self.transformer(x, attn_mask, key=k_transformer)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/gpt2.py", line 298, in __call__
    x = self.blocks.fold(x, attn_mask, hax.arange(self.config.Layers), key=keys)
  File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 221, in fold
    return haliax.fold(do_block, self.Block)(init, self.stacked, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 202, in scanned_f
    return scan_preconfig(init, *args, **kwargs)[0]
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 134, in scanned_f
    carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll)
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 127, in wrapped_fn
    carry, y = f(carry, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 197, in scan_compatible_fn
    return fn(carry, *args, **kwargs), None
  File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 75, in wrapper
    dynamic_out, static_out = checkpointed_fun(static, dynamic)
  File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 66, in _fn
    _out = fun(*_args, **_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 225, in _do_block
    return block(carry, *extra_args, **extra_kwargs)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/gpt2.py", line 268, in __call__
    attn_output = self.attn(self.ln_1(x), mask=mask, layer_idx=layer_idx, key=k1)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/gpt2.py", line 200, in __call__
    attn_output = dot_product_attention(
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/attention.py", line 119, in dot_product_attention
    attention_out = _try_te_attention(
  File "/levanter/src/levanter/models/attention.py", line 242, in _try_te_attention
    return _te_flash_attention(
  File "/levanter/src/levanter/models/attention.py", line 313, in _te_flash_attention
    from transformer_engine.jax.fused_attn import fused_attn  # noqa: F401
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/__init__.py", line 10, in <module>
    import transformer_engine.common
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 107, in <module>
    _TE_LIB_CTYPES = _load_library()
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 78, in _load_library
    return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
  File "/usr/lib/python3.10/ctypes/__init__.py", line 374, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /usr/local/lib/python3.10/dist-packages/transformer_engine/libtransformer_engine.so: undefined symbol: cudnnGetLastErrorString
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
    levanter.config.main(main)()
  File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/levanter/src/levanter/main/train_lm.py", line 215, in main
    trainer.train(state, train_loader)
  File "/levanter/src/levanter/trainer.py", line 403, in train
    for info in self.training_steps(state, train_loader, run_hooks=run_hooks):
  File "/levanter/src/levanter/trainer.py", line 386, in training_steps
    info = self.train_step(state, example)
  File "/levanter/src/levanter/trainer.py", line 370, in train_step
    loss, new_state = self._jit_train_step_fn(state, *batch, **batch_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 261, in __call__
    return self._call(False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/equinox/_module.py", line 1053, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 315, in _call
    output_shape = _cached_filter_eval_shape(self._fn, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 546, in _cached_filter_eval_shape
    _eval_shape_cache[static] = eqx.filter_eval_shape(fun, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/equinox/_eval_shape.py", line 38, in filter_eval_shape
    dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
  File "/levanter/src/levanter/trainer.py", line 498, in _train_step
    loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key)
  File "/levanter/src/levanter/trainer.py", line 515, in _compute_gradients_microbatched
    return grad_fn(model, *batch, **batch_kwargs)
  File "/levanter/src/levanter/grad_accum.py", line 92, in wrapped_fn
    r_shape = eqx.filter_eval_shape(fn, *args, **kwargs)
  File "/levanter/src/levanter/trainer.py", line 191, in fn
    return _ensure_scalar(self._raw_loss_function(model, *batch, **batch_kwargs))
  File "/levanter/src/levanter/types.py", line 75, in __call__
    return model.compute_loss(*inputs, reduction=reduction, reduction_axis=reduction_axis, **kwargs)
  File "/levanter/src/levanter/models/lm_model.py", line 129, in compute_loss
    logits = self(example.tokens, example.attn_mask, key=key)
  File "/levanter/src/levanter/models/gpt2.py", line 399, in __call__
    x = self.transformer(x, attn_mask, key=k_transformer)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/gpt2.py", line 298, in __call__
    x = self.blocks.fold(x, attn_mask, hax.arange(self.config.Layers), key=keys)
  File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 221, in fold
    return haliax.fold(do_block, self.Block)(init, self.stacked, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 202, in scanned_f
    return scan_preconfig(init, *args, **kwargs)[0]
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 134, in scanned_f
    carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll)
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 127, in wrapped_fn
    carry, y = f(carry, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 197, in scan_compatible_fn
    return fn(carry, *args, **kwargs), None
  File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 75, in wrapper
    dynamic_out, static_out = checkpointed_fun(static, dynamic)
  File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 66, in _fn
    _out = fun(*_args, **_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 225, in _do_block
    return block(carry, *extra_args, **extra_kwargs)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/gpt2.py", line 268, in __call__
    attn_output = self.attn(self.ln_1(x), mask=mask, layer_idx=layer_idx, key=k1)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/gpt2.py", line 200, in __call__
    attn_output = dot_product_attention(
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/attention.py", line 119, in dot_product_attention
    attention_out = _try_te_attention(
  File "/levanter/src/levanter/models/attention.py", line 242, in _try_te_attention
    return _te_flash_attention(
  File "/levanter/src/levanter/models/attention.py", line 313, in _te_flash_attention
    from transformer_engine.jax.fused_attn import fused_attn  # noqa: F401
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/__init__.py", line 10, in <module>
    import transformer_engine.common
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 107, in <module>
    _TE_LIB_CTYPES = _load_library()
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 78, in _load_library
    return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
  File "/usr/lib/python3.10/ctypes/__init__.py", line 374, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /usr/local/lib/python3.10/dist-packages/transformer_engine/libtransformer_engine.so: undefined symbol: cudnnGetLastErrorString
2024-07-25 20:44:28,532 WARNING worker.py:1450 -- SIGTERM handler is not set because current thread is not the main thread.
/usr/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = _posixsubprocess.fork_exec(
wandb: 
wandb: Run history:
wandb:              preprocessing//train/chunks ▁
wandb:            preprocessing//train/finished ▁
wandb:           preprocessing//train/input_ids ▁
wandb:                preprocessing//train/rows ▁
wandb:              preprocessing//train/shards ▁
wandb:      preprocessing//train/token_type_ids ▁
wandb:         preprocessing//validation/chunks ▁
wandb:       preprocessing//validation/finished ▁
wandb:      preprocessing//validation/input_ids ▁
wandb:           preprocessing//validation/rows ▁
wandb:         preprocessing//validation/shards ▁
wandb: preprocessing//validation/token_type_ids ▁
wandb: 
wandb: Run summary:
wandb:                                  backend gpu
wandb:                              num_devices 1
wandb:                                num_hosts 1
wandb:                          parameter_count 359708672
wandb:              preprocessing//train/chunks 1
wandb:            preprocessing//train/finished 1
wandb:           preprocessing//train/input_ids 285696
wandb:                preprocessing//train/rows 279
wandb:              preprocessing//train/shards 1
wandb:      preprocessing//train/token_type_ids 285696
wandb:         preprocessing//validation/chunks 1
wandb:       preprocessing//validation/finished 1
wandb:      preprocessing//validation/input_ids 28672
wandb:           preprocessing//validation/rows 28
wandb:         preprocessing//validation/shards 1
wandb: preprocessing//validation/token_type_ids 28672
wandb:                   throughput/device_kind NVIDIA A10G
wandb:             throughput/flops_per_example 2514493636608.0
wandb:             throughput/theoretical_flops 125000000000000.0
wandb:  throughput/theoretical_flops_per_device 125000000000000.0
wandb: 
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /levanter/wandb/offline-run-20240725_204346-cirlyl1x
wandb: Find logs at: ./wandb/offline-run-20240725_204346-cirlyl1x/logs
wandb: WARNING The new W&B backend becomes opt-out in version 0.18.0; try it out with `wandb.require("core")`! See https://wandb.me/wandb-core for more information.
2024-07-25 20:44:30,777 INFO worker.py:1779 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265 
2024-07-25T20:44:31 - 0 - ShardCache.cache/train - shard_cache.py:1418 - ERROR :: Error while reading from shard cache.
Traceback (most recent call last):
  File "/levanter/src/levanter/data/shard_cache.py", line 1406, in iter_batches_from_chunks
    chunk = self._get_chunk_unmapped(i)
  File "/levanter/src/levanter/data/shard_cache.py", line 1336, in _get_chunk_unmapped
    chunk = ray.get(self._broker.get_chunk.remote(mapped_index), timeout=current_timeout)
  File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 202, in remote
    return self._remote(args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/util/tracing/tracing_helper.py", line 426, in _start_span
    return method(self, args, kwargs, *_args, **_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 330, in _remote
    return invocation(args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 311, in invocation
    return actor._actor_method_call(
  File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 1460, in _actor_method_call
    object_refs = worker.core_worker.submit_actor_task(
  File "python/ray/_raylet.pyx", line 4258, in ray._raylet.CoreWorker.submit_actor_task
  File "python/ray/_raylet.pyx", line 4313, in ray._raylet.CoreWorker.submit_actor_task
Exception: Failed to submit task to actor ActorID(01757d23543161964c52264701000000) due to b"Can't find actor 01757d23543161964c52264701000000. It might be dead or it's from a different cluster"
MikeMpapa commented 1 month ago

Please ignore and let me retest. I wasn't on Levanter HEAD, so that might be it. Will follow up.

MikeMpapa commented 1 month ago

Yeah, same output unfortunately. Is there a way I can access older JAX containers?

INFO:levanter.distributed:Not initializing jax.distributed because no distributed config was provided, and no cluster was detected.
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:levanter.trainer:Setting run id to oz48brcz
2024-07-25T21:34:46 - 0 - levanter.tracker.wandb - wandb.py:200 - INFO :: Setting wandb code_dir to /levanter
2024-07-25T21:34:46 - 0 - wandb.sdk.lib.gitlib - gitlib.py:92 - ERROR :: git root error: Cmd('git') failed due to: exit code(128)
  cmdline: git rev-parse --show-toplevel
  stderr: 'fatal: detected dubious ownership in repository at '/levanter'
To add an exception for this directory, call:

        git config --global --add safe.directory /levanter'
2024-07-25T21:34:46 - 0 - wandb.sdk.lib.gitlib - gitlib.py:92 - ERROR :: git root error: Cmd('git') failed due to: exit code(128)
  cmdline: git rev-parse --show-toplevel
  stderr: 'fatal: detected dubious ownership in repository at '/levanter'
To add an exception for this directory, call:

        git config --global --add safe.directory /levanter'
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 3
wandb: You chose "Don't visualize my results"
wandb: WARNING `resume` will be ignored since W&B syncing is set to `offline`. Starting a new run with run id oz48brcz.
wandb: Tracking run with wandb version 0.17.5
wandb: W&B syncing is set to `offline` in this directory.  
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
2024-07-25T21:34:49 - 0 - levanter.distributed - distributed.py:215 - INFO :: No auto-discovered ray address found. Using ray.init('local').
2024-07-25T21:34:49 - 0 - levanter.distributed - distributed.py:267 - INFO :: ray.init(address='local', namespace='levanter', **{})
/usr/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = _posixsubprocess.fork_exec(
/usr/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = _posixsubprocess.fork_exec(
2024-07-25 21:34:51,416 INFO worker.py:1779 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265 
2024-07-25T21:34:52 - 0 - levanter.tracker.wandb - wandb.py:200 - INFO :: Setting wandb code_dir to /levanter
2024-07-25T21:34:52 - 0 - levanter.tracker.wandb - wandb.py:200 - INFO :: Setting wandb code_dir to /levanter
train:   0%|          | 0/50 [00:00<?, ?it/s]
2024-07-25T21:34:53 - 0 - levanter.data.shard_cache - shard_cache.py:1266 - INFO :: Loading cache from /cache/validation
2024-07-25T21:34:53 - 0 - levanter.data.text - text.py:692 - INFO :: Building cache for validation...
2024-07-25T21:34:53 - 0 - levanter.data.shard_cache - shard_cache.py:1266 - INFO :: Loading cache from /cache/validation
(ChunkCacheBuilder pid=505) 2024-07-25 21:34:59,609 - levanter.data.shard_cache.builder::cache/validation - INFO - Starting cache build for 1 shards
2024-07-25T21:35:06 - 0 - levanter.data.text - text.py:256 - INFO :: Cache /cache/validation is complete.
2024-07-25T21:35:06 - 0 - levanter.data.shard_cache - shard_cache.py:1266 - INFO :: Loading cache from /cache/train
2024-07-25T21:35:06 - 0 - levanter.data.text - text.py:692 - INFO :: Building cache for train...
2024-07-25T21:35:06 - 0 - levanter.data.shard_cache - shard_cache.py:1266 - INFO :: Loading cache from /cache/train
2024-07-25T21:35:06 - 0 - preprocessing..validation - metrics_monitor.py:143 - INFO ::  done: Shards: 0 | Chunks: 1 | Docs: 28
2024-07-25T21:35:06 - 0 - preprocessing..validation - metrics_monitor.py:143 - INFO ::  done: Shards: 1 | Chunks: 1 | Docs: 28
2024-07-25T21:35:06 - 0 - preprocessing..validation - metrics_monitor.py:150 - INFO :: Cache creation finished
(ChunkCacheBroker pid=460) 2024-07-25 21:35:06,723 - levanter.data.shard_cache - INFO - Finalizing cache /cache/validation...
(ChunkCacheBuilder pid=505) 2024-07-25 21:35:06,718 - levanter.data.shard_cache - INFO - Shard valid_txt finished
2024-07-25T21:35:10 - 0 - levanter.data.text - text.py:258 - INFO :: Cache /cache/train is incomplete. This will block until at least one chunk per process is complete.
(ChunkCacheBuilder pid=699) 2024-07-25 21:35:13,905 - levanter.data.shard_cache.builder::cache/train - INFO - Starting cache build for 1 shards
2024-07-25T21:35:14 - 0 - __main__ - train_lm.py:129 - INFO :: No training checkpoint found. Initializing model from HF checkpoint 'stanford-crfm/music-medium-800k'
config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.96k/1.96k [00:00<00:00, 8.35MB/s]
2024-07-25T21:35:21 - 0 - preprocessing..train - metrics_monitor.py:143 - INFO ::  done: Shards: 0 | Chunks: 1 | Docs: 279
2024-07-25T21:35:21 - 0 - preprocessing..train - metrics_monitor.py:143 - INFO ::  done: Shards: 1 | Chunks: 1 | Docs: 279
2024-07-25T21:35:21 - 0 - preprocessing..train - metrics_monitor.py:143 - INFO ::  done: Shards: 1 | Chunks: 1 | Docs: 279
2024-07-25T21:35:21 - 0 - preprocessing..train - metrics_monitor.py:150 - INFO :: Cache creation finished
(ChunkCacheBroker pid=664) 2024-07-25 21:35:21,188 - levanter.data.shard_cache - INFO - Finalizing cache /cache/train...
(ChunkCacheBuilder pid=699) 2024-07-25 21:35:21,183 - levanter.data.shard_cache - INFO - Shard train_txt finished
pytorch_model.bin: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.44G/1.44G [00:38<00:00, 37.6MB/s]
Loading weights: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 293/293 [00:01<00:00, 189.18it/s]
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
    levanter.config.main(main)()
  File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/levanter/src/levanter/main/train_lm.py", line 215, in main
    trainer.train(state, train_loader)
  File "/levanter/src/levanter/trainer.py", line 403, in train
    for info in self.training_steps(state, train_loader, run_hooks=run_hooks):
  File "/levanter/src/levanter/trainer.py", line 386, in training_steps
    info = self.train_step(state, example)
  File "/levanter/src/levanter/trainer.py", line 370, in train_step
    loss, new_state = self._jit_train_step_fn(state, *batch, **batch_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 261, in __call__
    return self._call(False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/equinox/_module.py", line 1053, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 315, in _call
    output_shape = _cached_filter_eval_shape(self._fn, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 546, in _cached_filter_eval_shape
    _eval_shape_cache[static] = eqx.filter_eval_shape(fun, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/equinox/_eval_shape.py", line 38, in filter_eval_shape
    dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
  File "/levanter/src/levanter/trainer.py", line 498, in _train_step
    loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key)
  File "/levanter/src/levanter/trainer.py", line 515, in _compute_gradients_microbatched
    return grad_fn(model, *batch, **batch_kwargs)
  File "/levanter/src/levanter/grad_accum.py", line 92, in wrapped_fn
    r_shape = eqx.filter_eval_shape(fn, *args, **kwargs)
  File "/levanter/src/levanter/trainer.py", line 191, in fn
    return _ensure_scalar(self._raw_loss_function(model, *batch, **batch_kwargs))
  File "/levanter/src/levanter/types.py", line 75, in __call__
    return model.compute_loss(*inputs, reduction=reduction, reduction_axis=reduction_axis, **kwargs)
  File "/levanter/src/levanter/models/lm_model.py", line 129, in compute_loss
    logits = self(example.tokens, example.attn_mask, key=key)
  File "/levanter/src/levanter/models/gpt2.py", line 399, in __call__
    x = self.transformer(x, attn_mask, key=k_transformer)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/gpt2.py", line 298, in __call__
    x = self.blocks.fold(x, attn_mask, hax.arange(self.config.Layers), key=keys)
  File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 221, in fold
    return haliax.fold(do_block, self.Block)(init, self.stacked, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 202, in scanned_f
    return scan_preconfig(init, *args, **kwargs)[0]
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 134, in scanned_f
    carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll)
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 127, in wrapped_fn
    carry, y = f(carry, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 197, in scan_compatible_fn
    return fn(carry, *args, **kwargs), None
  File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 75, in wrapper
    dynamic_out, static_out = checkpointed_fun(static, dynamic)
  File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 66, in _fn
    _out = fun(*_args, **_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 225, in _do_block
    return block(carry, *extra_args, **extra_kwargs)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/gpt2.py", line 268, in __call__
    attn_output = self.attn(self.ln_1(x), mask=mask, layer_idx=layer_idx, key=k1)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/gpt2.py", line 200, in __call__
    attn_output = dot_product_attention(
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/attention.py", line 119, in dot_product_attention
    attention_out = _try_te_attention(
  File "/levanter/src/levanter/models/attention.py", line 242, in _try_te_attention
    return _te_flash_attention(
  File "/levanter/src/levanter/models/attention.py", line 313, in _te_flash_attention
    from transformer_engine.jax.fused_attn import fused_attn  # noqa: F401
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/__init__.py", line 10, in <module>
    import transformer_engine.common
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 107, in <module>
    _TE_LIB_CTYPES = _load_library()
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 78, in _load_library
    return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
  File "/usr/lib/python3.10/ctypes/__init__.py", line 374, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /usr/local/lib/python3.10/dist-packages/transformer_engine/libtransformer_engine.so: undefined symbol: cudnnGetLastErrorString
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
    levanter.config.main(main)()
  File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/levanter/src/levanter/main/train_lm.py", line 215, in main
    trainer.train(state, train_loader)
  File "/levanter/src/levanter/trainer.py", line 403, in train
    for info in self.training_steps(state, train_loader, run_hooks=run_hooks):
  File "/levanter/src/levanter/trainer.py", line 386, in training_steps
    info = self.train_step(state, example)
  File "/levanter/src/levanter/trainer.py", line 370, in train_step
    loss, new_state = self._jit_train_step_fn(state, *batch, **batch_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 261, in __call__
    return self._call(False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/equinox/_module.py", line 1053, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 315, in _call
    output_shape = _cached_filter_eval_shape(self._fn, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 546, in _cached_filter_eval_shape
    _eval_shape_cache[static] = eqx.filter_eval_shape(fun, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/equinox/_eval_shape.py", line 38, in filter_eval_shape
    dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
  File "/levanter/src/levanter/trainer.py", line 498, in _train_step
    loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key)
  File "/levanter/src/levanter/trainer.py", line 515, in _compute_gradients_microbatched
    return grad_fn(model, *batch, **batch_kwargs)
  File "/levanter/src/levanter/grad_accum.py", line 92, in wrapped_fn
    r_shape = eqx.filter_eval_shape(fn, *args, **kwargs)
  File "/levanter/src/levanter/trainer.py", line 191, in fn
    return _ensure_scalar(self._raw_loss_function(model, *batch, **batch_kwargs))
  File "/levanter/src/levanter/types.py", line 75, in __call__
    return model.compute_loss(*inputs, reduction=reduction, reduction_axis=reduction_axis, **kwargs)
  File "/levanter/src/levanter/models/lm_model.py", line 129, in compute_loss
    logits = self(example.tokens, example.attn_mask, key=key)
  File "/levanter/src/levanter/models/gpt2.py", line 399, in __call__
    x = self.transformer(x, attn_mask, key=k_transformer)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/gpt2.py", line 298, in __call__
    x = self.blocks.fold(x, attn_mask, hax.arange(self.config.Layers), key=keys)
  File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 221, in fold
    return haliax.fold(do_block, self.Block)(init, self.stacked, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 202, in scanned_f
    return scan_preconfig(init, *args, **kwargs)[0]
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 134, in scanned_f
    carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll)
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 127, in wrapped_fn
    carry, y = f(carry, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 197, in scan_compatible_fn
    return fn(carry, *args, **kwargs), None
  File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 75, in wrapper
    dynamic_out, static_out = checkpointed_fun(static, dynamic)
  File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 66, in _fn
    _out = fun(*_args, **_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 225, in _do_block
    return block(carry, *extra_args, **extra_kwargs)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/gpt2.py", line 268, in __call__
    attn_output = self.attn(self.ln_1(x), mask=mask, layer_idx=layer_idx, key=k1)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/gpt2.py", line 200, in __call__
    attn_output = dot_product_attention(
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/attention.py", line 119, in dot_product_attention
    attention_out = _try_te_attention(
  File "/levanter/src/levanter/models/attention.py", line 242, in _try_te_attention
    return _te_flash_attention(
  File "/levanter/src/levanter/models/attention.py", line 313, in _te_flash_attention
    from transformer_engine.jax.fused_attn import fused_attn  # noqa: F401
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/__init__.py", line 10, in <module>
    import transformer_engine.common
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 107, in <module>
    _TE_LIB_CTYPES = _load_library()
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 78, in _load_library
    return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
  File "/usr/lib/python3.10/ctypes/__init__.py", line 374, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /usr/local/lib/python3.10/dist-packages/transformer_engine/libtransformer_engine.so: undefined symbol: cudnnGetLastErrorString
2024-07-25 21:36:00,899 WARNING worker.py:1450 -- SIGTERM handler is not set because current thread is not the main thread.
/usr/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = _posixsubprocess.fork_exec(
wandb: 
wandb: Run history:
wandb:              preprocessing//train/chunks ▁
wandb:            preprocessing//train/finished ▁
wandb:           preprocessing//train/input_ids ▁
wandb:                preprocessing//train/rows ▁
wandb:              preprocessing//train/shards ▁
wandb:      preprocessing//train/token_type_ids ▁
wandb:         preprocessing//validation/chunks ▁
wandb:       preprocessing//validation/finished ▁
wandb:      preprocessing//validation/input_ids ▁
wandb:           preprocessing//validation/rows ▁
wandb:         preprocessing//validation/shards ▁
wandb: preprocessing//validation/token_type_ids ▁
wandb: 
wandb: Run summary:
wandb:                                  backend gpu
wandb:                              num_devices 1
wandb:                                num_hosts 1
wandb:                          parameter_count 359708672
wandb:              preprocessing//train/chunks 1
wandb:            preprocessing//train/finished 1
wandb:           preprocessing//train/input_ids 285696
wandb:                preprocessing//train/rows 279
wandb:              preprocessing//train/shards 1
wandb:      preprocessing//train/token_type_ids 285696
wandb:         preprocessing//validation/chunks 1
wandb:       preprocessing//validation/finished 1
wandb:      preprocessing//validation/input_ids 28672
wandb:           preprocessing//validation/rows 28
wandb:         preprocessing//validation/shards 1
wandb: preprocessing//validation/token_type_ids 28672
wandb:                   throughput/device_kind NVIDIA A10G
wandb:             throughput/flops_per_example 2514493636608.0
wandb:             throughput/theoretical_flops 125000000000000.0
wandb:  throughput/theoretical_flops_per_device 125000000000000.0
wandb: 
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /levanter/wandb/offline-run-20240725_213448-oz48brcz
wandb: Find logs at: ./wandb/offline-run-20240725_213448-oz48brcz/logs
wandb: WARNING The new W&B backend becomes opt-out in version 0.18.0; try it out with `wandb.require("core")`! See https://wandb.me/wandb-core for more information.
2024-07-25 21:36:03,171 INFO worker.py:1779 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265 
2024-07-25T21:36:04 - 0 - ShardCache.cache/train - shard_cache.py:1418 - ERROR :: Error while reading from shard cache.
Traceback (most recent call last):
  File "/levanter/src/levanter/data/shard_cache.py", line 1406, in iter_batches_from_chunks
    chunk = self._get_chunk_unmapped(i)
  File "/levanter/src/levanter/data/shard_cache.py", line 1336, in _get_chunk_unmapped
    chunk = ray.get(self._broker.get_chunk.remote(mapped_index), timeout=current_timeout)
  File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 202, in remote
    return self._remote(args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/util/tracing/tracing_helper.py", line 426, in _start_span
    return method(self, args, kwargs, *_args, **_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 330, in _remote
    return invocation(args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 311, in invocation
    return actor._actor_method_call(
  File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 1460, in _actor_method_call
    object_refs = worker.core_worker.submit_actor_task(
  File "python/ray/_raylet.pyx", line 4258, in ray._raylet.CoreWorker.submit_actor_task
  File "python/ray/_raylet.pyx", line 4313, in ray._raylet.CoreWorker.submit_actor_task
Exception: Failed to submit task to actor ActorID(0e2b26ec3dd9161148784ee601000000) due to b"Can't find actor 0e2b26ec3dd9161148784ee601000000. It might be dead or it's from a different cluster"
Exception in thread ray_print_logs:
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 893, in print_logs
    subscriber.subscribe()
  File "python/ray/_raylet.pyx", line 3111, in ray._raylet._GcsSubscriber.subscribe
  File "python/ray/_raylet.pyx", line 586, in ray._raylet.check_status
ray.exceptions.RpcError: recvmsg:Connection reset by peer
dlwh commented 1 month ago

This seems like a problem with the CUDA stuff. @DwarKapex, any thoughts?
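If it helps narrow things down, the dlopen failure reproduces outside of training with a few lines of ctypes (a diagnostic sketch; the path is taken from the traceback above). An undefined `cudnnGetLastErrorString` usually means TransformerEngine was built against a newer cuDNN than the one the loader resolves at runtime (cuDNN 9 adds that symbol, I believe):

```python
import ctypes

# Load TransformerEngine's native library directly, reproducing the dlopen
# error from the traceback without running a training step.
SO_PATH = "/usr/local/lib/python3.10/dist-packages/transformer_engine/libtransformer_engine.so"

try:
    ctypes.CDLL(SO_PATH, mode=ctypes.RTLD_GLOBAL)
    print("libtransformer_engine.so loaded OK")
except OSError as e:
    # e.g. "undefined symbol: cudnnGetLastErrorString" -> cuDNN version mismatch
    print("dlopen failed:", e)
```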

dlwh commented 1 month ago

(You can probably use https://github.com/orgs/nvidia/packages/container/jax/248105500?tag=levanter-2024-07-24 as `jax:levanter-2024-07-24`.)

MikeMpapa commented 1 month ago

that did the trick! Thanks so much!

MikeMpapa commented 1 month ago

If it is of any help, here is a slightly more detailed description of what I am trying to do.

Sharing this in case it helps with debugging. Feel free to follow up if you have any questions.

dlwh commented 1 month ago

Weird. That's easy to work around, but I don't know why it's happening.