I installed the dependencies from the pyproject.toml file and get the following error when importing MoiraiModule. Does anyone know which jaxtyping version is the right dependency?
  File "/home/ujp5kor/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/ujp5kor/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/home/ujp5kor/scratch/code/uni2ts/main.py", line 11, in <module>
    from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
  File "/fs/scratch/rng_cr_bcai_dl/ujp5kor/code/uni2ts/src/uni2ts/model/moirai/__init__.py", line 16, in <module>
    from .finetune import MoiraiFinetune
  File "/fs/scratch/rng_cr_bcai_dl/ujp5kor/code/uni2ts/src/uni2ts/model/moirai/finetune.py", line 60, in <module>
    from .module import MoiraiModule
  File "/fs/scratch/rng_cr_bcai_dl/ujp5kor/code/uni2ts/src/uni2ts/model/moirai/module.py", line 25, in <module>
    from uni2ts.distribution import DistributionOutput
  File "/fs/scratch/rng_cr_bcai_dl/ujp5kor/code/uni2ts/src/uni2ts/distribution/__init__.py", line 16, in <module>
    from ._base import DistributionOutput, DistrParamProj
  File "/fs/scratch/rng_cr_bcai_dl/ujp5kor/code/uni2ts/src/uni2ts/distribution/_base.py", line 23, in <module>
    from jaxtyping import Float, Int, PyTree
ImportError: cannot import name 'PyTree' from 'jaxtyping' (/home/ujp5kor/scratch/conda/envs/aim_404/lib/python3.11/site-packages/jaxtyping/__init__.py)
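To narrow down which jaxtyping version works, it may help to report exactly what is installed in the failing environment. Below is a minimal diagnostic sketch (not part of uni2ts; the function names `package_version` and `diagnose_jaxtyping` are hypothetical) that prints the installed versions of jaxtyping, jax and torch and retries the import that fails in `uni2ts/distribution/_base.py`. It also reports the jax version because, as an assumption to verify against the jaxtyping changelog, some jaxtyping releases may only export `PyTree` when JAX is importable.

```python
import importlib.metadata
import importlib.util
from typing import Optional


def package_version(name: str) -> Optional[str]:
    """Return the installed distribution version, or None if it is not installed."""
    try:
        return importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        return None


def diagnose_jaxtyping() -> dict:
    """Collect versions and retry the failing `from jaxtyping import PyTree`."""
    report = {
        "jaxtyping": package_version("jaxtyping"),
        "jax": package_version("jax"),  # assumption: PyTree export may depend on JAX
        "torch": package_version("torch"),
        "pytree_importable": False,
        "error": None,
    }
    if importlib.util.find_spec("jaxtyping") is None:
        report["error"] = "jaxtyping is not installed"
        return report
    try:
        # Same import as uni2ts/distribution/_base.py, line 23
        from jaxtyping import Float, Int, PyTree  # noqa: F401
        report["pytree_importable"] = True
    except ImportError as exc:
        report["error"] = str(exc)
    return report


if __name__ == "__main__":
    for key, value in diagnose_jaxtyping().items():
        print(f"{key}: {value}")
```

Running this inside the `aim_404` environment and pasting the output alongside the traceback would show whether the problem is the jaxtyping version itself or a missing companion package.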