jthickstun / anticipation

Anticipatory Autoregressive Models
Apache License 2.0

Fine-Tuning with Levanter + Anticipation: Environment Setup + safetensors vs pytorch_model.bin #14

Open ianberman opened 7 months ago

ianberman commented 7 months ago

Hello,

I created a new conda environment in wsl and proceeded to install levanter and anticipation from source. I was able to adapt and run the tokenization scripts for my needs and produced what seem to be non-empty tokenized files to use for fine-tuning.

However, when I install from requirements.txt after I install levanter, I get the following warning: levanter 1.1 requires transformers>=4.32.0, but you have transformers 4.29.2 which is incompatible.

If I proceed with training anyway, I get the following error:

Traceback (most recent call last):
  File "/home/ian/miniconda3/envs/anticipation/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ian/miniconda3/envs/anticipation/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/main/train_lm.py", line 194, in <module>
    levanter.config.main(main)()
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/config.py", line 83, in wrapper_inner
    cfg = parse(config_class=argtype, config_path=config_path, args=cmdline_args)
  File "/home/ian/miniconda3/envs/anticipation/lib/python3.10/site-packages/draccus/argparsing.py", line 186, in parse
    parser = ArgumentParser(
  File "/home/ian/miniconda3/envs/anticipation/lib/python3.10/site-packages/draccus/argparsing.py", line 73, in __init__
    self._set_dataclass(config_class)  # type: ignore
  File "/home/ian/miniconda3/envs/anticipation/lib/python3.10/site-packages/draccus/argparsing.py", line 87, in _set_dataclass
    new_wrapper.register_actions(parser=self.parser)
  File "/home/ian/miniconda3/envs/anticipation/lib/python3.10/site-packages/draccus/wrappers/dataclass_wrapper.py", line 64, in register_actions
    child.register_actions(parser)
  File "/home/ian/miniconda3/envs/anticipation/lib/python3.10/site-packages/draccus/wrappers/choice_wrapper.py", line 61, in register_actions
    children = self._children
  File "/home/ian/miniconda3/envs/anticipation/lib/python3.10/functools.py", line 981, in __get__
    val = self.func(instance)
  File "/home/ian/miniconda3/envs/anticipation/lib/python3.10/site-packages/draccus/wrappers/choice_wrapper.py", line 97, in _children
    return {name: _wrap_child(child) for name, child in self.choice_type.get_known_choices().items()}
  File "/home/ian/miniconda3/envs/anticipation/lib/python3.10/site-packages/draccus/choice_types.py", line 190, in get_known_choices
    cls._discover_packages()
  File "/home/ian/miniconda3/envs/anticipation/lib/python3.10/site-packages/draccus/choice_types.py", line 209, in _discover_packages
    importlib.import_module(pkg_name)
  File "/home/ian/miniconda3/envs/anticipation/lib/python3.10/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/models/mistral.py", line 29, in <module>
    from transformers import MistralConfig as HfMistralConfig  # noqa: E402
ImportError: cannot import name 'MistralConfig' from 'transformers' (/home/ian/miniconda3/envs/anticipation/lib/python3.10/site-packages/transformers/__init__.py)

This seems to indicate that MistralConfig is not in transformers 4.29.2, so I tried updating transformers to several later versions such as 4.34 and 4.35. That began to create dependency conflicts between tokenizers, transformers and huggingface-hub, so I then just used the latest transformers, 4.39, which seems to be OK dependency-wise between those 3 packages. When I then try training with python -m levanter.main.train_lm --config_path ./config/finetune.yaml I get the following error:

INFO:levanter.distributed:Not initializing jax.distributed because no distributed config was provided, and no cluster was detected.
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:levanter.trainer:Setting run id to gjdk5n7w
2024-03-21T13:19:48 - 0 - levanter.tracker.wandb - wandb.py:200 - INFO :: Setting wandb code_dir to /mnt/c/Users/Ian/GitHub/levanter
wandb: Currently logged in as: ijb. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.4
wandb: Run data is saved locally in /mnt/c/Users/Ian/GitHub/anticipation/wandb/run-20240321_131949-gjdk5n7w
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run playful-universe-7
wandb: ⭐️ View project at https://wandb.ai/ijb/anticipation
wandb: 🚀 View run at https://wandb.ai/ijb/anticipation/runs/gjdk5n7w
2024-03-21T13:19:52 - 0 - levanter.distributed - distributed.py:215 - INFO :: No auto-discovered ray address found. Using ray.init('local').
2024-03-21T13:19:52 - 0 - levanter.distributed - distributed.py:267 - INFO :: ray.init(address='local', namespace='levanter', **{})
/home/ian/miniconda3/envs/anticipation/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = _posixsubprocess.fork_exec(
/home/ian/miniconda3/envs/anticipation/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = _posixsubprocess.fork_exec(
2024-03-21 13:19:54,472 INFO worker.py:1715 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265
2024-03-21T13:19:55 - 0 - levanter.tracker.wandb - wandb.py:200 - INFO :: Setting wandb code_dir to /mnt/c/Users/Ian/GitHub/levanter
2024-03-21T13:19:55 - 0 - levanter.tracker.wandb - wandb.py:200 - INFO :: Setting wandb code_dir to /mnt/c/Users/Ian/GitHub/levanter
train:   0%|                                                                                   | 0/2001 [00:00<?, ?it/s]2024-03-21T13:19:56 - 0 - levanter.data.shard_cache - shard_cache.py:1575 - INFO :: Loading cache from /cache/validation
2024-03-21T13:19:56 - 0 - levanter.data.text - text.py:720 - INFO :: Building cache for validation...
[]
Traceback (most recent call last):
  File "/home/ian/miniconda3/envs/anticipation/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ian/miniconda3/envs/anticipation/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/main/train_lm.py", line 194, in <module>
    levanter.config.main(main)()
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/main/train_lm.py", line 105, in main
    eval_datasets = config.data.validation_sets(Pos.size)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/data/text.py", line 675, in validation_sets
    validation_set = self.validation_set(seq_len, monitors)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/data/text.py", line 670, in validation_set
    return self.token_seq_dataset("validation", seq_len, monitors)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/data/text.py", line 699, in token_seq_dataset
    cache = self.build_or_load_cache(split, monitors=monitors)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/data/text.py", line 730, in build_or_load_cache
    return TokenizedDocumentCache.build_or_load(
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/data/text.py", line 242, in build_or_load
    bt = BatchTokenizer(tokenizer, enforce_eos=enforce_eos, override_resources=override_resources)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/data/text.py", line 356, in __init__
    should_append_eos = input_ids[-1] != tokenizer.eos_token_id and enforce_eos
IndexError: list index out of range
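
For readers following the traceback: the failing line indexes `input_ids[-1]`, which raises `IndexError` when the tokenizer returns an empty list of ids (the `[]` printed just before the traceback is consistent with that). Below is a minimal sketch of the same check with a guard for the empty case. The function name and the choice to treat an empty sequence as "needs EOS" are assumptions made for illustration; this is not Levanter's code or a proposed patch.

```python
def should_append_eos(input_ids, eos_token_id, enforce_eos):
    """Guarded variant of the check on the failing line (illustrative only).

    Returns True when EOS should be appended: enforcement is on and the
    sequence does not already end with the EOS id. An empty sequence is
    treated as "does not end with EOS" (an assumption of this sketch).
    """
    if not enforce_eos:
        return False
    if not input_ids:
        # The unguarded expression input_ids[-1] raises IndexError here.
        return True
    return input_ids[-1] != eos_token_id


if __name__ == "__main__":
    try:
        [][-1]  # what the original expression does on an empty list
    except IndexError as err:
        print("unguarded:", err)
    print("guarded:", should_append_eos([], eos_token_id=0, enforce_eos=True))
```

If the tokenizer really does return no ids for ordinary input, the guard only hides the symptom; the tokenizer setup in the config would still need a look.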

Lastly, regarding your reply here: https://github.com/jthickstun/anticipation/issues/7#issuecomment-1837565701

The code that your commenting-out "hack" applies to looks a bit different in the latest levanter main than in the version you linked to - not sure why. The latest version seems to be here https://github.com/stanford-crfm/levanter/blob/43712f12bddf8e2827783b6276d8d5373563d866/src/levanter/main/train_lm.py#L63

Should I check out a past version of levanter? I wasn't exactly sure: going by the dates, your instructions seem to have been written after the latest commit to levanter.

Anyway, my hunch is that it's related to the levanter version or transformers version incompatibility, but maybe something broke in the tokenization step? Here is my finetune.yaml also. Any thoughts appreciated:)
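
To make the version question concrete, here is a small standard-library sketch that compares an installed package version against a minimum such as the `transformers>=4.32.0` from the pip warning above. The helper names are made up for this example, and the parsing is deliberately naive (numeric components only, pre-release suffixes ignored), so treat it as a quick sanity check rather than a replacement for pip's resolver.

```python
import re
from importlib import metadata


def parse_version(text):
    """Turn '4.29.2' into (4, 29, 2); stops at the first non-numeric piece."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """True if `installed` >= `minimum`, padding with zeros so 4.32 == 4.32.0."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_installed(package, minimum):
    """Describe whether the installed `package` satisfies `minimum`."""
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        return f"{package}: not installed"
    status = "ok" if meets_minimum(installed, minimum) else "too old"
    return f"{package} {installed} (need >= {minimum}): {status}"


if __name__ == "__main__":
    # 4.32.0 is the minimum quoted in the pip warning above.
    print(check_installed("transformers", "4.32.0"))
```

With the versions mentioned in this comment, `meets_minimum("4.29.2", "4.32.0")` is false and `meets_minimum("4.39.0", "4.32.0")` is true.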

jthickstun commented 7 months ago

I've usually set up a separate conda environment for running Levanter. That said, if you want to run both codebases in the same environment, I would start by installing Levanter & then try to get the anticipation repo working with Levanter's versions of requirements. In particular, I think it should be safe to switch to transformers >= 4.29.2 in the anticipation repo's requirements.txt. Let me know if there are any surprises running anticipation with the newer versions of libraries required by Levanter & I'll be happy to take a look.

Re: Issue #7, I tried running Levanter just now & realize that things have changed a bit since I made those comments: see my latest comment on Issue #7 for instructions on getting a music model running with the latest version of Levanter. I'm working with the Levanter team & hopeful that we'll be able to avoid needing any of these hacks in the near future.

ianberman commented 7 months ago

Thank you for the quick and helpful reply! It didn't occur to me I could just use a separate environment for levanter 😅

Now I am largely up and running using your suggested finetune.yaml file, but it seems that levanter expects a safetensors file and stanford-crfm/music-medium-800k is a pytorch_model.bin file?

terminal output:

```
2024-03-22T00:33:50 - 0 - __main__ - train_lm.py:124 - INFO :: No training checkpoint found. Initializing model from HF checkpoint 'stanford-crfm/music-medium-800k'
pytorch_model.bin: 100%|██████████████████████████████████████████████████████████████████| 1.44G/1.44G [01:57<00:00, 12.3MB/s]
Traceback (most recent call last):
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py", line 304, in hf_raise_for_status
    response.raise_for_status()
  File "/home/ian/.local/lib/python3.10/site-packages/requests/models.py", line 1021, in raise_for_status
    raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 404 Client Error: Not Found for url: https://huggingface.co/stanford-crfm/music-medium-800k/resolve/main/model.safetensors

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/hf_checkpoints.py", line 460, in load_state_dict
    model_path = hf_hub_download(id, SAFE_TENSORS_MODEL, revision=rev)
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 118, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1261, in hf_hub_download
    metadata = get_hf_file_metadata(
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 118, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 1667, in get_hf_file_metadata
    r = _request_wrapper(
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 385, in _request_wrapper
    response = _request_wrapper(
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/huggingface_hub/file_download.py", line 409, in _request_wrapper
    hf_raise_for_status(response)
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py", line 315, in hf_raise_for_status
    raise EntryNotFoundError(message, response) from e
huggingface_hub.utils._errors.EntryNotFoundError: 404 Client Error. (Request ID: Root=1-65fcc3e0-0cc9513e4e7e5d901bd2db22;21b8fc2f-5aaf-4d2f-b65c-b9809dcaaa4e)

Entry Not Found for url: https://huggingface.co/stanford-crfm/music-medium-800k/resolve/main/model.safetensors.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/main/train_lm.py", line 194, in <module>
    levanter.config.main(main)()
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/main/train_lm.py", line 131, in main
    model = converter.load_pretrained(
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/hf_checkpoints.py", line 540, in load_pretrained
    state_dict = self.load_state_dict(ref, dtype)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/hf_checkpoints.py", line 464, in load_state_dict
    state_dict = _load_torch(model_path, dtype)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/hf_checkpoints.py", line 180, in _load_torch
    import torch
ModuleNotFoundError: No module named 'torch'
```

edit: I tried converting it to safetensors myself but am getting the error: AttributeError: 'NoneType' object has no attribute 'data_ptr'

edit 2: Related issue -- I'm following along with the colab example in my local environment, and noticed that the large model being in safetensors format becomes an issue :

LARGE_MODEL = 'stanford-crfm/music-large-800k'
model = AutoModelForCausalLM.from_pretrained(LARGE_MODEL).cuda()

error: OSError: stanford-crfm/music-large-800k does not appear to have a file named pytorch_model.bin, tf_model.h5, model.ckpt or flax_model.msgpack.

jthickstun commented 7 months ago

Frustrating. This change in checkpoint format happened between when I trained the smaller models, and the new large model I just released. In my own testing, it looked like the old .bin and new .safetensors checkpoints were largely substitutable for one another. But clearly there can be some issues.

> it seems that levanter expects a safetensors file and stanford-crfm/music-medium-800k is a pytorch_model.bin file?

Oh: you just need to install torch, e.g., pip install torch (it's not one of the default requirements for Levanter, but it's used to load checkpoints).

Re: the error loading stanford-crfm/music-large-800k. This is indeed stored in the new .safetensors format (the older models are stored using the old .bin). My first thought is that the pinned version, transformers 4.29.2, may be too old to support safetensors checkpoints. It seemed to be working in colab, but possibly that environment is ignoring the version request & installing something newer? Upgrading your local transformers library might fix this?
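
The traceback above shows the loader trying `model.safetensors` first and only then falling back to the torch loader, which is where the missing `torch` import surfaced. A rough sketch of that preference order, written as a pure function over a list of file names, is below. The constant `PYTORCH_MODEL`, the function name, and the example file listings are assumptions for illustration (only `SAFE_TENSORS_MODEL` appears in the traceback), and sharded checkpoints are not handled.

```python
SAFE_TENSORS_MODEL = "model.safetensors"  # name seen in the 404 URL above
PYTORCH_MODEL = "pytorch_model.bin"       # assumed constant name for this sketch


def pick_weight_file(repo_files):
    """Return (filename, needs_torch) for the preferred weight file.

    Prefers safetensors; falls back to the pickled .bin file, which
    needs torch installed to be read.
    """
    names = set(repo_files)
    if SAFE_TENSORS_MODEL in names:
        return SAFE_TENSORS_MODEL, False
    if PYTORCH_MODEL in names:
        return PYTORCH_MODEL, True
    raise FileNotFoundError(
        f"neither {SAFE_TENSORS_MODEL} nor {PYTORCH_MODEL} in {sorted(names)}"
    )


if __name__ == "__main__":
    # Illustrative listings only; a real one could come from
    # huggingface_hub.list_repo_files(repo_id), which needs network access.
    medium_like = ["config.json", "pytorch_model.bin"]
    large_like = ["config.json", "model.safetensors"]
    print(pick_weight_file(medium_like))
    print(pick_weight_file(large_like))
```

Under these assumptions the medium-style listing resolves to the `.bin` file with `needs_torch=True`, matching the advice to `pip install torch`.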

ianberman commented 7 months ago

Cool, thank you. I didn't realize it was as easy as installing torch!

If it helps anyone else, I had to install everything (jax, torch, etc) for cuda 11.8; the jax for cuda 12 is too recent for torch.

I now have fine-tuning the medium model up and running on my 3090. However, it's quite slow - about 40s/it with 512 batch size; half that with 256.

I'm getting the warning /mnt/c/Users/Ian/GitHub/levanter/src/levanter/models/attention.py:89: UserWarning: transformer_engine is not installed. Please install it to use NVIDIA's optimized fused attention. Falling back to the reference implementation.

however i can enter python and import transformer_engine:

(levanter) ian@DESKTOP-MDE3NSV:~$ python
Python 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import transformer_engine
/home/ian/miniconda3/envs/levanter/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
>>>

maybe this is a levanter issue, let me know if i should file an issue over there, or if I am overlooking something.

here is the full terminal output:

```bash
(levanter) ian@DESKTOP-MDE3NSV:/mnt/c/Users/Ian/GitHub/anticipation$ python -m levanter.main.train_lm --config_path config/finetune.yaml
INFO:levanter.distributed:Not initializing jax.distributed because no distributed config was provided, and no cluster was detected.
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:levanter.trainer:Setting run id to vfrzbivg
2024-03-23T16:28:56 - 0 - levanter.tracker.wandb - wandb.py:200 - INFO :: Setting wandb code_dir to /mnt/c/Users/Ian/GitHub/levanter
wandb: Currently logged in as: ijb. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.4
wandb: Run data is saved locally in /mnt/c/Users/Ian/GitHub/anticipation/wandb/run-20240323_162857-vfrzbivg
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run misunderstood-moon-20
wandb: ⭐️ View project at https://wandb.ai/ijb/anticipation
wandb: 🚀 View run at https://wandb.ai/ijb/anticipation/runs/vfrzbivg
2024-03-23T16:29:00 - 0 - levanter.distributed - distributed.py:215 - INFO :: No auto-discovered ray address found. Using ray.init('local').
2024-03-23T16:29:00 - 0 - levanter.distributed - distributed.py:267 - INFO :: ray.init(address='local', namespace='levanter', **{})
/home/ian/miniconda3/envs/levanter/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = _posixsubprocess.fork_exec(
/home/ian/miniconda3/envs/levanter/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = _posixsubprocess.fork_exec(
2024-03-23 16:29:02,440 INFO worker.py:1743 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265
2024-03-23T16:29:03 - 0 - levanter.tracker.wandb - wandb.py:200 - INFO :: Setting wandb code_dir to /mnt/c/Users/Ian/GitHub/levanter
2024-03-23T16:29:03 - 0 - levanter.tracker.wandb - wandb.py:200 - INFO :: Setting wandb code_dir to /mnt/c/Users/Ian/GitHub/levanter
train: 0%| | 0/2001 [00:00
```

jthickstun commented 7 months ago

> I now have fine-tuning the medium model up and running on my 3090. However, it's quite slow - about 40s/it with 512 batch size; half that with 256.

You could definitely try running with a smaller batch size! The batch size of 512 was copied from my pre-training configuration and is likely larger (maybe much larger) than it needs to be for effective finetuning. Lots of people have success fine-tuning language and image models with small batch sizes & I would expect similar results for fine-tuning these music models.
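
As a rough worked example of the numbers quoted above: 40 s/it at batch size 512 and about 20 s/it at 256 (reading "half that" as 20 s) come out to the same throughput of 12.8 sequences per second, so a smaller batch shortens each step without making the GPU process data any faster. The 2001-step count is taken from the progress bar in the logs; the helper names are made up for this sketch.

```python
def sequences_per_second(batch_size, seconds_per_step):
    """Training throughput in sequences per second."""
    return batch_size / seconds_per_step


def total_hours(num_steps, seconds_per_step):
    """Wall-clock hours needed for a fixed number of optimizer steps."""
    return num_steps * seconds_per_step / 3600


if __name__ == "__main__":
    # (batch size, seconds per step); the 20 s figure is an approximation.
    for batch_size, step_time in [(512, 40.0), (256, 20.0)]:
        rate = sequences_per_second(batch_size, step_time)
        hours = total_hours(2001, step_time)
        print(f"batch {batch_size}: {rate:.1f} seq/s, {hours:.1f} h for 2001 steps")
```

At these rates a 2001-step run takes roughly 22 hours at batch size 512 and roughly 11 hours at 256, with the smaller batch seeing half as many sequences in total.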

> If it helps anyone else, I had to install everything (jax, torch, etc) for cuda 11.8; the jax for cuda 12 is too recent for torch.

It should be possible to get things running on cuda 12 (I'm running everything on cuda 12 on linux), but maybe there are some additional complexities to running things in wsl.

> maybe this is a levanter issue, let me know if i should file an issue over there, or if I am overlooking something.

Yeah, I don't know about the transformer_engine details; that's all internal to Levanter.

ianberman commented 7 months ago

Thanks so much for your help with all of this!

> It should be possible to get things running on cuda 12 (I'm running everything on cuda 12 on linux), but maybe there are some additional complexities to running things in wsl.

Just to follow up on this: my issue was that if I installed torch with its packaged cuda 12.1, jax (built for cuda 12.3, I think) conflicted with it and then couldn't train with cuda at all, so I just installed the cuda 11.8 builds of both jax and torch. I tried using an older version of jax for cuda 12.1 (jax==0.4.16) but ran into issues doing this - ImportError: cannot import name 'DTypeLike' from 'jax.typing' (/usr/local/lib/python3.10/dist-packages/jax/typing.py)
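
When torch and jax disagree about CUDA like this, it can help to print what each library reports in the active environment. The sketch below is one way to do that. It assumes nothing beyond `torch.version.cuda`, `torch.cuda.is_available()` and `jax.default_backend()`, and it degrades to "not installed" when a library is missing; the function name and dictionary keys are made up for this example.

```python
def cuda_report():
    """Collect the CUDA-related facts torch and jax report, if installed."""
    report = {}
    try:
        import torch
    except ImportError:
        report["torch"] = "not installed"
    else:
        report["torch"] = torch.__version__
        # CUDA version torch was built against; None for CPU-only builds.
        report["torch_cuda_build"] = torch.version.cuda
        report["torch_cuda_available"] = torch.cuda.is_available()
    try:
        import jax
    except ImportError:
        report["jax"] = "not installed"
    else:
        report["jax"] = jax.__version__
        try:
            # 'gpu' when jax found a usable CUDA backend, 'cpu' otherwise.
            report["jax_backend"] = jax.default_backend()
        except RuntimeError as err:
            report["jax_backend"] = f"failed to initialize: {err}"
    return report


if __name__ == "__main__":
    for key, value in cuda_report().items():
        print(f"{key}: {value}")
```

A torch build that reports one CUDA version next to a jax backend that falls back to `cpu` would be consistent with the conflict described here, though this check alone does not say which package to change.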


Anyway, I've been trying fine-tuning runs with a batch size of 64 or 32 just to test it out, and I get as far as saving the second checkpoint. Then, however, I get a crash with the following output. I tried in a wsl environment and in a docker environment (with wsl backend) installed as in levanter's instructions - same issue.

terminal output

```bash
train: 38%|███████████████████████▏ | 760/2001 [30:21<48:56, 2.37s/it, loss=0.153]
2024-03-24T16:15:36 - 0 - levanter.checkpoint - checkpoint.py:194 - INFO :: Saving temporary checkpoint at step 759.
2024-03-24T16:15:36 - 0 - levanter.checkpoint - checkpoint.py:245 - INFO :: Saving checkpoint at step 759 to /mnt/c/Users/Ian/GitHub/anticipation/checkpoints/fdw2hoae/step-759
2024-03-24T16:15:36 - 0 - levanter.checkpoint - checkpoint.py:285 - INFO :: Saving checkpoint to /mnt/c/Users/Ian/GitHub/anticipation/checkpoints/fdw2hoae/step-759 for step 759
2024-03-24T16:15:36 - 0 - jax.experimental.array_serialization.serialization - serialization.py:529 - INFO :: Waiting for previous serialization to finish.
2024-03-24T16:15:36 - 0 - jax.experimental.array_serialization.serialization - serialization.py:488 - INFO :: Error check finished successfully
2024-03-24T16:15:42 - 0 - jax.experimental.array_serialization.serialization - serialization.py:430 - INFO :: Starting commit to storage layer by process: 0
train: 39%|███████████████████████▌ | 772/2001 [30:56<49:01, 2.39s/it, loss=0.203]
(raylet) [2024-03-24 16:16:13,409 E 10547 10547] (raylet) node_manager.cc:2967: 14 Workers (tasks / actors) killed due to memory pressure (OOM), 0 Workers crashed due to other reasons at node (ID: e7fbc0805b72529f05da5a3ad8c944bcdc58af6d1f602358b987438c, IP: 172.18.67.159) over the last time period. To see more information about the Workers killed on this node, use `ray logs raylet.out -ip 172.18.67.159`
(raylet)
(raylet) Refer to the documentation on how to address the out of memory issue: https://docs.ray.io/en/latest/ray-core/scheduling/ray-oom-prevention.html. Consider provisioning more memory on this node or reducing task parallelism by requesting more CPUs per task. To adjust the kill threshold, set the environment variable `RAY_memory_usage_threshold` when starting Ray. To disable worker killing, set the environment variable `RAY_memory_monitor_refresh_ms` to zero.
train: 50%|██████████████████████████████ | 1001/2001 [39:55<39:12, 2.35s/it, loss=0.166]
2024-03-24T16:25:10 - 0 - levanter.checkpoint - checkpoint.py:192 - INFO :: Saving checkpoint at step 1000.
2024-03-24T16:25:10 - 0 - levanter.checkpoint - checkpoint.py:245 - INFO :: Saving checkpoint at step 1000 to /mnt/c/Users/Ian/GitHub/anticipation/checkpoints/fdw2hoae/step-1000
2024-03-24T16:25:10 - 0 - levanter.checkpoint - checkpoint.py:285 - INFO :: Saving checkpoint to /mnt/c/Users/Ian/GitHub/anticipation/checkpoints/fdw2hoae/step-1000 for step 1000
2024-03-24T16:25:10 - 0 - jax.experimental.array_serialization.serialization - serialization.py:529 - INFO :: Waiting for previous serialization to finish.
2024-03-24T16:25:10 - 0 - jax.experimental.array_serialization.serialization - serialization.py:485 - INFO :: Thread joined successfully
Traceback (most recent call last):
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/main/train_lm.py", line 194, in <module>
    levanter.config.main(main)()
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/main/train_lm.py", line 189, in main
    trainer.train(state, train_loader)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/trainer.py", line 403, in train
    for info in self.training_steps(state, train_loader, run_hooks=run_hooks):
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/trainer.py", line 391, in training_steps
    self.run_hooks(info)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/trainer.py", line 231, in run_hooks
    self.hooks.run_hooks(info, force=force)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/trainer.py", line 108, in run_hooks
    hook.fn(info)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/checkpoint.py", line 209, in on_step
    self.save_checkpoint(info, destination, commit_callback=callback)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/checkpoint.py", line 248, in save_checkpoint
    save_checkpoint(
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/checkpoint.py", line 300, in save_checkpoint
    tree_serialize_leaves_tensorstore(checkpoint_path, tree, manager, commit_callback=my_callback)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/tensorstore_serialization.py", line 76, in tree_serialize_leaves_tensorstore
    manager.serialize_with_paths(arrays, paths, on_commit_callback=commit_callback)
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/jax/experimental/array_serialization/serialization.py", line 550, in serialize_with_paths
    self.serialize(arrays, tspecs, on_commit_callback=on_commit_callback)
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/jax/experimental/array_serialization/serialization.py", line 530, in serialize
    self.wait_until_finished()
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/jax/experimental/array_serialization/serialization.py", line 487, in wait_until_finished
    self.check_for_errors()
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/jax/experimental/array_serialization/serialization.py", line 479, in check_for_errors
    raise exception # pylint: disable=raising-bad-type
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/jax/experimental/array_serialization/serialization.py", line 434, in _thread_func
    future.result()
ValueError: NOT_FOUND: Error opening "zarr" driver: Error writing local file "/mnt/c/Users/Ian/GitHub/anticipation/checkpoints/fdw2hoae/step-759/step/.zarray": Error getting file info: /mnt/c/Users/Ian/GitHub/anticipation/checkpoints/fdw2hoae/step-759/step/.zarray.__lock [OS error: No such file or directory] [tensorstore_spec='{\"context\":{\"cache_pool\":{},\"data_copy_concurrency\":{},\"file_io_concurrency\":{\"limit\":128},\"file_io_sync\":true},\"create\":true,\"driver\":\"zarr\",\"dtype\":\"int32\",\"kvstore\":{\"driver\":\"file\",\"path\":\"/mnt/c/Users/Ian/GitHub/anticipation/checkpoints/fdw2hoae/step-759/step/\"},\"metadata\":{\"chunks\":[],\"compressor\":{\"id\":\"zstd\",\"level\":1},\"dtype\":\"
```

I tried two different runs, and both crash when saving the second checkpoint. The first time around, it is able to write a bunch of files to the target directory, including the .zarray file, but I don't see a lock file. The .zarray file contains the text below:

{"chunks":[],"compressor":{"id":"zstd","level":1},"dimension_separator":".","dtype":"<i4","fill_value":null,"filters":null,"order":"C","shape":[],"zarr_format":2}


Lastly, with regards to loading the large model, I'm using a separate levanter conda environment now, which I believe uses a recent transformers, so it's no trouble opening a safetensors file. However, I'm getting the following error related to configuration, maybe I just need to change something in the finetune.yaml for this to work?

terminal output

```bash
Loading weights: 100%|███████████████████████████████████████████████████████████████| 436/436 [00:03<00:00, 138.81it/s]
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/main/train_lm.py", line 194, in <module>
    levanter.config.main(main)()
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/main/train_lm.py", line 131, in main
    model = converter.load_pretrained(
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/hf_checkpoints.py", line 578, in load_pretrained
    lev_model = eqx.filter_jit(load_from_state_dict, donate="all", device=cpu_device)(state_dict)
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/equinox/_jit.py", line 206, in __call__
    return self._call(False, args, kwargs)
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/equinox/_module.py", line 935, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/equinox/_jit.py", line 198, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/hf_checkpoints.py", line 551, in load_from_state_dict
    lev_model = lev_model.from_state_dict(state_dict, prefix=ignore_prefix)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/torch_serialization.py", line 63, in from_state_dict
    return default_eqx_module_from_state_dict(self, state_dict, prefix)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/torch_serialization.py", line 168, in default_eqx_module_from_state_dict
    new = jax_tree_from_state_dict(value, state_dict, apply_prefix(prefix, key))
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/torch_serialization.py", line 77, in jax_tree_from_state_dict
    return tree.from_state_dict(state_dict, prefix)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/models/gpt2.py", line 296, in from_state_dict
    out = super().from_state_dict(stacked, prefix=prefix)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/torch_serialization.py", line 63, in from_state_dict
    return default_eqx_module_from_state_dict(self, state_dict, prefix)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/torch_serialization.py", line 168, in default_eqx_module_from_state_dict
    new = jax_tree_from_state_dict(value, state_dict, apply_prefix(prefix, key))
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/torch_serialization.py", line 79, in jax_tree_from_state_dict
    return default_eqx_module_from_state_dict(tree, state_dict, prefix)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/torch_serialization.py", line 168, in default_eqx_module_from_state_dict
    new = jax_tree_from_state_dict(value, state_dict, apply_prefix(prefix, key))
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/torch_serialization.py", line 77, in jax_tree_from_state_dict
    return tree.from_state_dict(state_dict, prefix)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/torch_serialization.py", line 63, in from_state_dict
    return default_eqx_module_from_state_dict(self, state_dict, prefix)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/torch_serialization.py", line 168, in default_eqx_module_from_state_dict
    new = jax_tree_from_state_dict(value, state_dict, apply_prefix(prefix, key))
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/torch_serialization.py", line 79, in jax_tree_from_state_dict
    return default_eqx_module_from_state_dict(tree, state_dict, prefix)
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/torch_serialization.py", line 168, in default_eqx_module_from_state_dict
    new = jax_tree_from_state_dict(value, state_dict, apply_prefix(prefix, key))
  File "/mnt/c/Users/Ian/GitHub/levanter/src/levanter/compat/torch_serialization.py", line 103, in jax_tree_from_state_dict
    array = haliax.named(array, tree.axes)
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/haliax/core.py", line 1162, in named
    axes = check_shape(a.shape, axis)
  File "/home/ian/miniconda3/envs/levanter/lib/python3.10/site-packages/haliax/core.py", line 1394, in check_shape
    raise ValueError(f"Shape mismatch: jnp_shape={jnp_shape} hax_axes={hax_axes}")
ValueError: Shape mismatch: jnp_shape=(36, 1280) hax_axes=(Axis(name='layers', size=24), Axis(name='embed', size=1024))
```

jthickstun commented 7 months ago

Lastly, with regards to loading the large model [...] maybe I just need to change something in the finetune.yaml for this to work?

This I can help with: yes, you need to change the finetune.yaml, which is currently configured to describe the architecture of the medium model. For the large model, you'll need to update the following values in the config:

  hidden_dim: 1280
  num_heads: 20
  num_layers: 36
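
These values line up with the shape mismatch in the traceback: the checkpoint holds arrays of shape (36, 1280), while the config declared 24 layers and an embedding size of 1024. As a rough sketch, the three fields can be read off a checkpoint's config rather than typed by hand. This assumes the checkpoint ships a Hugging Face GPT-2-style `config.json` with the keys `n_embd`, `n_head`, and `n_layer`; the helper name `yaml_fields_from_hf_config` is hypothetical and not part of Levanter or anticipation.

```python
import json

# Assumed mapping: Hugging Face GPT-2 config keys -> the field names quoted above.
HF_TO_YAML = {"n_embd": "hidden_dim", "n_head": "num_heads", "n_layer": "num_layers"}


def yaml_fields_from_hf_config(hf_config: dict) -> dict:
    """Return the architecture fields for finetune.yaml from a GPT-2-style config dict."""
    return {yaml_key: hf_config[hf_key] for hf_key, yaml_key in HF_TO_YAML.items()}


# Stand-in for json.load(open("config.json")); the numbers are the large-model values above.
large_config = json.loads('{"n_embd": 1280, "n_head": 20, "n_layer": 36}')
large_fields = yaml_fields_from_hf_config(large_config)

for key, value in large_fields.items():
    print(f"  {key}: {value}")
```

Reading the values from the checkpoint's own config avoids the kind of mismatch seen here, where the weights and the declared architecture disagree.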

I get as far as saving the second checkpoint. However then I get a crash with the following output.

We are getting into questions here that might be better addressed on the Levanter issues board. I think you might be the first person to ever try running Levanter on wsl!

This error looks like it might be downstream of an out-of-memory issue:

(raylet) [2024-03-24 16:16:13,409 E 10547 10547] (raylet) node_manager.cc:2967: 14 Workers (tasks / actors) killed due to memory pressure (OOM)
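
One way to check whether memory pressure is plausible is to look at how much memory is available around the time of a checkpoint save. The sketch below parses `/proc/meminfo`-style text, which is available on Linux and WSL. The helper names and the sample values are illustrative assumptions, not output from the run above.

```python
def parse_meminfo(text: str) -> dict:
    """Parse /proc/meminfo-style text into a {field_name: kibibytes} dict."""
    fields = {}
    for line in text.splitlines():
        name, _, rest = line.partition(":")
        parts = rest.split()
        if parts and parts[0].isdigit():
            fields[name.strip()] = int(parts[0])
    return fields


def available_fraction(text: str) -> float:
    """Fraction of total memory that the kernel reports as available."""
    info = parse_meminfo(text)
    return info["MemAvailable"] / info["MemTotal"]


# Illustrative sample; on a real machine use open("/proc/meminfo").read() instead.
SAMPLE = (
    "MemTotal:       16000000 kB\n"
    "MemFree:         1000000 kB\n"
    "MemAvailable:    4000000 kB\n"
)
print(f"available: {available_fraction(SAMPLE):.0%}")
```

If the available fraction is already small before a checkpoint is written, the OOM explanation becomes more likely.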
MikeMpapa commented 3 months ago

Joining a bit late, but what would be the appropriate torch version to use?

It was running fine for me until yesterday just by running pip install torch, but suddenly that causes a CuDNN incompatibility error.

Loaded runtime CuDNN library: 9.1.0 but source was compiled with: 9.2.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
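
The rule stated in that message (matching major version, equal or higher minor version) can be written as a small check. This is only a sketch of the rule as worded in the error, with a hypothetical helper name; it is not how torch itself performs the check.

```python
def cudnn_runtime_ok(runtime: str, compiled: str) -> bool:
    """True if the runtime CuDNN has the same major and an equal-or-higher minor version."""
    r_major, r_minor = (int(part) for part in runtime.split(".")[:2])
    c_major, c_minor = (int(part) for part in compiled.split(".")[:2])
    return r_major == c_major and r_minor >= c_minor


# The versions from the error message: runtime 9.1.0, compiled against 9.2.0.
print(cudnn_runtime_ok("9.1.0", "9.2.0"))
```

With the versions from the message, the minor version of the runtime library (1) is lower than the compiled one (2), so the check fails.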

EDIT: As of the day of this comment, it works fine with torch==2.3.0