google-research / timesfm

TimesFM (Time Series Foundation Model) is a pretrained time-series foundation model developed by Google Research for time-series forecasting.
https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/
Apache License 2.0

Install error "ModuleNotFoundError: No module named 'jax'" #172

Open: linpingta opened this issue 1 week ago

linpingta commented 1 week ago

Hi guys,

If I simply install the library with "pip install timesfm" and run the example code from https://huggingface.co/google/timesfm-1.0-200m:

import timesfm

tfm = timesfm.TimesFm(
    context_len=14,
    horizon_len=7,
    input_patch_len=32,
    output_patch_len=128,
    num_layers=20,
    model_dims=1280
)
tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")

It raises an error:

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/timesfm/timesfm_base.py:27
     23 import pandas as pd
     25 from utilsforecast.processing import make_future_dataframe
---> 27 from . import xreg_lib
     29 Category = xreg_lib.Category
     30 XRegMode = xreg_lib.XRegMode

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/timesfm/xreg_lib.py:20
     17 import math
     18 from typing import Any, Iterable, Literal, Mapping, Sequence
---> 20 import jax
     21 import jax.numpy as jnp
     22 import numpy as np

ModuleNotFoundError: No module named 'jax'

I then installed jax manually, but ran into another error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[12], line 3
      1 import timesfm
----> 3 tfm = timesfm.TimesFm(
      4     context_len=14,
      5     horizon_len=7,
      6     input_patch_len=32,
      7     output_patch_len=128,
      8     num_layers=20,
      9     model_dims=1280
     10 )
     11 tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")

TypeError: TimesFmBase.__init__() got an unexpected keyword argument 'context_len'

Could you help me figure out what went wrong? Thanks.

aleksmaksimovic commented 1 week ago

Create a new venv with Python 3.10.x and install it this way:

pip install timesfm[pax]
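After recreating the venv and installing the extra, a quick standard-library check can confirm that the module from the first traceback is now present. This is a minimal sketch: the `missing_modules` helper is hypothetical (not part of timesfm), and it assumes the `pax` extra is what brings in `jax`.

```python
import importlib.util
import sys


def missing_modules(names):
    """Return the top-level module names from `names` that cannot be found.

    importlib.util.find_spec only locates a top-level module without
    executing it, so a package that is present but broken is not detected.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    # "jax" is the module timesfm/xreg_lib.py fails to import in the traceback.
    missing = missing_modules(["jax", "timesfm"])
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("jax and timesfm are both importable")
```

If `jax` is still reported missing, the venv running the notebook is probably not the one the extra was installed into.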

linpingta commented 1 week ago

Thanks @aleksmaksimovic. I think the example at https://huggingface.co/google/timesfm-1.0-200m needs an update: TimesFm doesn't accept the hyperparameters directly as keyword arguments; they have to be wrapped in a TimesFmHparams object, with a TimesFmCheckpoint passed alongside it.
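For reference, a minimal sketch of the model-card example rewritten for the hparams/checkpoint API. It assumes a timesfm 1.2.x release where `TimesFm(hparams=..., checkpoint=...)` is the constructor and `TimesFmCheckpoint(huggingface_repo_id=...)` names the Hugging Face repo, as in the project README. `build_tfm` is a hypothetical helper, and `context_len=512` and `backend="cpu"` are assumed values, not taken from this thread. The `import timesfm` sits inside the function so the sketch loads even where timesfm is not installed.

```python
def build_tfm(context_len=512, horizon_len=128, backend="cpu"):
    """Build the timesfm-1.0-200m model via TimesFmHparams/TimesFmCheckpoint."""
    import timesfm  # lazy import: only needed when the model is actually built

    # The keyword arguments from the old example now go into the hparams object.
    hparams = timesfm.TimesFmHparams(
        backend=backend,
        context_len=context_len,
        horizon_len=horizon_len,
        input_patch_len=32,
        output_patch_len=128,
        num_layers=20,
        model_dims=1280,
        per_core_batch_size=32,
    )
    # The checkpoint location is passed as its own object.
    checkpoint = timesfm.TimesFmCheckpoint(
        huggingface_repo_id="google/timesfm-1.0-200m",
    )
    # Assumption: in 1.2.x the constructor loads the checkpoint itself, so no
    # separate load_from_checkpoint call is made here. Check your version.
    return timesfm.TimesFm(hparams=hparams, checkpoint=checkpoint)
```

Calling `tfm = build_tfm()` downloads the checkpoint; forecasting then follows step 3 of the docstring quoted in this comment (`tfm.forecast(...)`).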

From the library code:

class TimesFmBase:
  """Base TimesFM forecast API for inference.

  This class is the scaffolding for calling TimesFM forecast. To properly use:
    1. Create an instance with the correct hyperparameters of a TimesFM model.
    2. Call `load_from_checkpoint` to load a compatible checkpoint.
    3. Call `forecast` for inference.
  """

  def _logging(self, s):
    print(s)

  def __post_init__(self) -> None:
    """Additional initialization for subclasses before checkpoint loading."""
    pass

  def __init__(self, hparams: TimesFmHparams,
               checkpoint: TimesFmCheckpoint) -> None:
    """Initializes the TimesFM forecast API.
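The earlier `TypeError` follows directly from this signature: Python rejects any keyword argument that `__init__` does not declare, which is why the error names `context_len` rather than the missing `hparams`. A minimal, self-contained reproduction with hypothetical stand-in classes (`Hparams`, `Checkpoint` and `ForecasterBase` are not timesfm names):

```python
import dataclasses


@dataclasses.dataclass
class Hparams:
    """Hypothetical stand-in for timesfm.TimesFmHparams (fields trimmed)."""
    context_len: int = 512
    horizon_len: int = 128


@dataclasses.dataclass
class Checkpoint:
    """Hypothetical stand-in for timesfm.TimesFmCheckpoint."""
    huggingface_repo_id: str = ""


class ForecasterBase:
    """Hypothetical stand-in with the same __init__ shape as TimesFmBase."""

    def __init__(self, hparams: Hparams, checkpoint: Checkpoint) -> None:
        self.hparams = hparams
        self.checkpoint = checkpoint


def old_style_call():
    """Pass hyperparameters directly, as the old example did; return the error."""
    try:
        ForecasterBase(context_len=14, horizon_len=7)
    except TypeError as err:
        return str(err)
    return None


# Wrapping the hyperparameters in their own object satisfies the signature.
new_style = ForecasterBase(
    hparams=Hparams(context_len=512, horizon_len=128),
    checkpoint=Checkpoint(huggingface_repo_id="google/timesfm-1.0-200m"),
)
```

On Python 3.10 `old_style_call()` returns a message of the same form as the one in the traceback, prefixed with the stand-in class name.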

Correct me if I'm wrong. Thanks!