google-research / timesfm

TimesFM (Time Series Foundation Model) is a pretrained time-series foundation model developed by Google Research for time-series forecasting.
https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/
Apache License 2.0

Error in loading checkpoint #43

Closed yanghou2000 closed 1 month ago

yanghou2000 commented 1 month ago

Background

Linux x86, TimesFM CPU version, jobs submitted via Slurm. I have already ensured that the conda env is activated after the SBATCH directives and before the Python code runs.

Code that ran into the error

tfm = timesfm.TimesFm(
    context_len=480,
    horizon_len=14,
    input_patch_len=32,  # fixed
    output_patch_len=128,  # fixed
    num_layers=20,  # fixed
    model_dims=1280,  # fixed
    backend="cpu",
)

tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path, step=1100000)

Description

I downloaded the model checkpoint from a Hugging Face mirror website and stored it at this path: /repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint. I'm not sure what the right value of checkpoint_path is for tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path).

The questions I want to ask are:

  1. What should my_checkpoint_path be in my case? I tried all the plausible choices and none worked; the error messages are shown below.
  2. What is the size of the checkpoint for those of you with a working example? I used curl -L to download the checkpoint from a mirror website to the server, and the size is 777 MB, which is odd because downloading from the same link directly to my local machine (a Mac) gives 814.3 MB.
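On question 2: 777 × 1.048576 ≈ 814.7, so the mismatch may simply be the same file reported in binary MiB (curl / ls -lh) versus decimal MB (macOS Finder). To rule out a truncated download, comparing a hash of both copies is more reliable than comparing sizes; a minimal sketch (the checkpoint path is illustrative):

```python
import hashlib


def file_digest(path: str, chunk: int = 1 << 20) -> tuple[int, str]:
    """Return (size_in_bytes, sha256_hex) so two downloads can be compared exactly."""
    h = hashlib.sha256()
    size = 0
    with open(path, "rb") as f:
        # Read in 1 MiB chunks so large checkpoints don't need to fit in memory.
        while block := f.read(chunk):
            h.update(block)
            size += len(block)
    return size, h.hexdigest()


# Illustrative path; run the same check on both machines and compare the output.
# print(file_digest("/repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint"))
```

If the two hex digests match, the files are byte-identical and the size difference is purely a units artifact.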

Error message

  1. When I use

    my_checkpoint_path = "/repo/timesfm_model/checkpoints"
    tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path)

    the corresponding error message is:

    WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
    WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
    WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
    WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume `train_state` is unpadded.
    ERROR:absl:For checkpoint version > 1.0, we require users to provide
          `train_state_unpadded_shape_dtype_struct` during checkpoint
          saving/restoring, to avoid potential silent bugs when loading
          checkpoints to incompatible unpadded shapes of TrainState.
  2. When I use

    my_checkpoint_path = "/repo/timesfm_model/checkpoints/checkpoint_1100000" 
    # or my my_checkpoint_path = "/repo/timesfm_model/checkpoints/checkpoint_1100000/state" 
    tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path)

    the error message is:

    WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
    WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
    WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
    WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume `train_state` is unpadded.
    Constructing model weights.
    Constructed model weights in 2.28 seconds.
    Restoring checkpoint from /repo/timesfm_model/checkpoints/checkpoint_1100000.
    Traceback (most recent call last):
      File "/./bin/python_script/timesfm_pred.py", line 92, in <module>
        main()
      File "/./bin/python_script/timesfm_pred.py", line 72, in main
        tfm.load_from_checkpoint(checkpoint_path="/ssd1/cache/hpc_t0/hy/repo/timesfm_model/checkpoints/checkpoint_1100000", step=1100000)
      File "repo/timesfm/src/timesfm.py", line 270, in load_from_checkpoint
        self._train_state = checkpoints.restore_checkpoint(
      File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoints.py", line 246, in restore_checkpoint
        output = checkpoint_manager.restore(
      File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoint_managers.py", line 568, in restore
        restored = self._manager.restore(
      File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1054, in restore
        restore_directory = self._get_read_step_directory(step, directory)
      File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 811, in _get_read_step_directory
        return self._options.step_name_format.find_step(root_dir, step).path
      File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/path/step.py", line 66, in find_step
        raise ValueError(
    ValueError: No step path found for step=1100000 with NameFormat=PaxStepNameFormat(checkpoint_type=<CheckpointType.FLAX: 'flax'>, use_digit_step_subdirectory=False) under /ssd1/cache/hpc_t0/hy/repo/timesfm_model/checkpoints/checkpoint_1100000
  3. When I use

    my_checkpoint_path = "/repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint" 
    tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path)

    the error message is

    WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
    WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
    Constructing model weights.
    Constructed model weights in 2.42 seconds.
    Restoring checkpoint from /repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint.
    Traceback (most recent call last):
      File "/./bin/python_script/timesfm_pred.py", line 92, in <module>
        main()
      File "/./bin/python_script/timesfm_pred.py", line 72, in main
        tfm.load_from_checkpoint(checkpoint_path="/ssd1/cache/hpc_t0/hy/repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint", step=1100000)
      File "/ssd1/cache/hpc_t0/hy/repo/timesfm/src/timesfm.py", line 270, in load_from_checkpoint
        self._train_state = checkpoints.restore_checkpoint(
      File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoints.py", line 227, in restore_checkpoint
        checkpoint_manager = checkpoint_managers.OrbaxCheckpointManager(
      File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoint_managers.py", line 451, in __init__
        self._manager = _CheckpointManagerImpl(
      File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoint_managers.py", line 290, in __init__
        step = self.any_step()
      File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoint_managers.py", line 370, in any_step
        any_step = ocp.utils.any_checkpoint_step(self.directory)
      File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/utils.py", line 714, in any_checkpoint_step
        for s in checkpoint_dir.iterdir():
      File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/etils/epath/gpath.py", line 156, in iterdir
        for f in self._backend.listdir(self._path_str):
      File "/envs/tfm_env/lib/python3.10/site-packages/etils/epath/backend.py", line 142, in listdir
        return [p for p in os.listdir(path) if not p.endswith('~')]
    NotADirectoryError: [Errno 20] Not a directory: '/ssd1/cache/hpc_t0/hy/repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint'
siriuz42 commented 1 month ago
my_checkpoint_path = "/repo/timesfm_model/checkpoints"
tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path)

is the correct way to call. One "checkpoint" is the set of everything under checkpoint_1100000, and here we point to the parent directory.
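For reference, a small helper (hypothetical, not part of the timesfm API) that derives the directory to pass from any file inside the downloaded checkpoint tree, by walking up to the parent of the checkpoint_<step> directory:

```python
from pathlib import Path


def checkpoint_root(path: str) -> str:
    """Given any path inside a Pax checkpoint tree, return the directory
    that contains the checkpoint_<step> folder, i.e. what
    load_from_checkpoint expects as checkpoint_path."""
    p = Path(path)
    for parent in [p, *p.parents]:
        name = parent.name
        # Match directories of the form checkpoint_<digits>, e.g. checkpoint_1100000.
        if name.startswith("checkpoint_") and name[len("checkpoint_"):].isdigit():
            return str(parent.parent)
    raise ValueError(f"no checkpoint_<step> directory found in {path}")


# checkpoint_root("/repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint")
# -> "/repo/timesfm_model/checkpoints"
```

The returned string is then passed straight to tfm.load_from_checkpoint(checkpoint_path=...).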

Quick question, can you run model inference despite the error message in (1), or is it interrupting?

yanghou2000 commented 1 month ago
my_checkpoint_path = "/repo/timesfm_model/checkpoints"
tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path)

is the correct way to call. One "checkpoint" is the set of everything under checkpoint_1100000, and here we point to the parent directory.

Quick question, can you run model inference despite the error message in (1), or is it interrupting?

Thank you for your swift reply! After pointing to the parent directory, model inference ran once I requested more memory via SBATCH when submitting the Slurm job. In other words, the earlier error in (1) was caused by running out of memory, not by a bug in the code.
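For anyone hitting the same out-of-memory kill under Slurm: job memory is requested with the #SBATCH --mem directive. A minimal sketch of a batch script (job name, memory value, env name, and script path are all illustrative; tune --mem until the checkpoint restore no longer gets killed):

```shell
#!/bin/bash
#SBATCH --job-name=timesfm_pred
#SBATCH --mem=16G           # illustrative value; the CPU restore needs several GB
#SBATCH --cpus-per-task=4

# Activate the env after the SBATCH directives, before running Python.
conda activate tfm_env
python timesfm_pred.py
```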

yanghou2000 commented 1 month ago

Let me summarize and close this issue for now.

Summary

My working example code is below:

# Load the timesfm model
tfm = timesfm.TimesFm(
    context_len=480,
    horizon_len=14,
    input_patch_len=32,  # fixed
    output_patch_len=128,  # fixed
    num_layers=20,  # fixed
    model_dims=1280,  # fixed
    backend="cpu",
)

tfm.load_from_checkpoint(checkpoint_path="/repo/timesfm_model/checkpoints")
xyskywalker commented 2 weeks ago

I also encountered the same problem. I made sure the directory was correct, but it still produced this error:

[*********************100%%**********************]  1 of 1 completed
Constructing model weights.
WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume `train_state` is unpadded.
Constructed model weights in 2.20 seconds.
Restoring checkpoint from /home/wd/PycharmProjects/timesfm-1.0-200m/checkpoints/.
Restored checkpoint in 4.69 seconds.
Jitting decoding.
ERROR:absl:For checkpoint version > 1.0, we require users to provide
          `train_state_unpadded_shape_dtype_struct` during checkpoint
          saving/restoring, to avoid potential silent bugs when loading
          checkpoints to incompatible unpadded shapes of TrainState.

Process finished with exit code 137 (interrupted by signal 9:SIGKILL)
LouisLee1983 commented 2 days ago

WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume `train_state` is unpadded.
ERROR:absl:For checkpoint version > 1.0, we require users to provide `train_state_unpadded_shape_dtype_struct` during checkpoint saving/restoring, to avoid potential silent bugs when loading checkpoints to incompatible unpadded shapes of TrainState.
Restored checkpoint in 0.75 seconds.
Jitting decoding.
Killed

Could someone tell me what is going wrong here? My orbax-checkpoint is version 0.5.9.

LouisLee1983 commented 18 hours ago

Everyone, I found the cause: WSL did not have enough memory. You need to raise WSL's memory limit to 16 GB or more.
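For reference, the WSL 2 memory cap is set in a `.wslconfig` file in the Windows user profile directory; a sketch (the values are illustrative, and `wsl --shutdown` must be run afterwards for the change to apply):

```ini
; %UserProfile%\.wslconfig
[wsl2]
memory=16GB
swap=8GB
```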

yanghou2000 commented 9 hours ago

I also encountered the same problem, and I also made sure that the directory was correct, but it still prompted this error:

[*********************100%%**********************]  1 of 1 completed
Constructing model weights.
WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume `train_state` is unpadded.
Constructed model weights in 2.20 seconds.
Restoring checkpoint from /home/wd/PycharmProjects/timesfm-1.0-200m/checkpoints/.
Restored checkpoint in 4.69 seconds.
Jitting decoding.
ERROR:absl:For checkpoint version > 1.0, we require users to provide
          `train_state_unpadded_shape_dtype_struct` during checkpoint
          saving/restoring, to avoid potential silent bugs when loading
          checkpoints to incompatible unpadded shapes of TrainState.

Process finished with exit code 137 (interrupted by signal 9:SIGKILL)

I think this is due to a lack of memory. Try again after giving your program more memory.