hspark1212 / MOFTransformer

Universal Transfer Learning in Porous Materials, including MOFs.
https://hspark1212.github.io/MOFTransformer/

torchmetrics version error #140

Closed: MINGUUUS closed this issue 1 year ago

MINGUUUS commented 1 year ago

Hello! I found a version incompatibility with torchmetrics.

After installing moftransformer, I ran from moftransformer.utils.download import download_pretrain_model, following the documentation.

However, I got the following error:

$ ipython

In [1]: from moftransformer.utils.download import download_pretrain_model

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[1], line 1
----> 1 from moftransformer.utils.download import download_pretrain_model

File ~/MOFTransformer/moftransformer/__init__.py:7
      4 __version__ = "2.1.1"
      5 __root_dir__ = os.path.dirname(__file__)
----> 7 from moftransformer import visualize, utils, modules, libs, gadgets, datamodules, assets
      8 from moftransformer.run import run
      9 from moftransformer.predict import predict

File ~/MOFTransformer/moftransformer/visualize/__init__.py:2
      1 # MOFTransformer version 2.0.0
----> 2 from moftransformer.visualize.visualizer import PatchVisualizer
      4 __all__ = ["PatchVisualizer"]

File ~/MOFTransformer/moftransformer/visualize/visualizer.py:12
      9 from matplotlib.colors import ListedColormap
     10 from matplotlib import animation
---> 12 from moftransformer.visualize.utils import (
     13     get_structure,
     14     get_heatmap,
     15     scaler,
     16     get_model_and_datamodule,
     17     get_batch_from_index,
     18     get_batch_from_cif_id,
     19 )
     20 from moftransformer.visualize.setting import (
     21     get_fig_ax,
     22     set_fig_ax,
   (...)
     27     get_cmap,
     28 )
     29 from moftransformer.visualize.drawer import (
     30     draw_cell,
     31     draw_atoms,
   (...)
     34     draw_heatmap_graph,
     35 )

File ~/MOFTransformer/moftransformer/visualize/utils.py:16
     14 from pymatgen.io.ase import AseAtomsAdaptor
     15 import torch
---> 16 import pytorch_lightning as pl
     17 from moftransformer.modules.module import Module
     18 from moftransformer.datamodules.datamodule import Datamodule

File ~/anaconda3/envs/PMTransformer/lib/python3.9/site-packages/pytorch_lightning/__init__.py:34
     31     _logger.addHandler(logging.StreamHandler())
     32     _logger.propagate = False
---> 34 from pytorch_lightning.callbacks import Callback  # noqa: E402
     35 from pytorch_lightning.core import LightningDataModule, LightningModule  # noqa: E402
     36 from pytorch_lightning.trainer import Trainer  # noqa: E402

File ~/anaconda3/envs/PMTransformer/lib/python3.9/site-packages/pytorch_lightning/callbacks/__init__.py:25
     23 from pytorch_lightning.callbacks.model_summary import ModelSummary
     24 from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
---> 25 from pytorch_lightning.callbacks.progress import ProgressBarBase, RichProgressBar, TQDMProgressBar
     26 from pytorch_lightning.callbacks.pruning import ModelPruning
     27 from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining

File ~/anaconda3/envs/PMTransformer/lib/python3.9/site-packages/pytorch_lightning/callbacks/progress/__init__.py:22
     14 """
     15 Progress Bars
     16 =============
   (...)
     19
     20 """
     21 from pytorch_lightning.callbacks.progress.base import ProgressBarBase  # noqa: F401
---> 22 from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar  # noqa: F401
     23 from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar  # noqa: F401

File ~/anaconda3/envs/PMTransformer/lib/python3.9/site-packages/pytorch_lightning/callbacks/progress/rich_progress.py:20
     17 from datetime import timedelta
     18 from typing import Any, Dict, Optional, Union
---> 20 from torchmetrics.utilities.imports import _compare_version
     22 import pytorch_lightning as pl
     23 from pytorch_lightning.callbacks.progress.base import ProgressBarBase

ImportError: cannot import name '_compare_version' from 'torchmetrics.utilities.imports' (/home/mingyu/anaconda3/envs/PMTransformer/lib/python3.9/site-packages/torchmetrics/utilities/imports.py)
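The last frame shows the root cause: this pytorch_lightning release imports _compare_version from torchmetrics.utilities.imports at startup, and the installed torchmetrics no longer provides that name. A minimal sketch for probing this directly, without triggering the full import chain, is below. The helper name has_symbol is hypothetical and not part of moftransformer, pytorch_lightning, or torchmetrics.

```python
import importlib


def has_symbol(module_name: str, name: str) -> bool:
    """Return True if `module_name` imports cleanly and defines `name`.

    Hypothetical helper for diagnosing "cannot import name" errors.
    ModuleNotFoundError (package missing) is a subclass of ImportError,
    so both cases return False.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    return hasattr(module, name)


if __name__ == "__main__":
    # The symbol that pytorch_lightning's rich_progress.py tries to import.
    ok = has_symbol("torchmetrics.utilities.imports", "_compare_version")
    print("_compare_version available:", ok)
```

If this prints False, the installed torchmetrics is too new for the installed pytorch_lightning.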

The installed version of torchmetrics was 1.0.0; downgrading torchmetrics resolved the error:

pip uninstall torchmetrics
pip install torchmetrics==0.11.1
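A minimal sketch for checking the installed torchmetrics version before importing moftransformer, using only the standard library. The function name torchmetrics_is_compatible and the "major version 0 only" rule are assumptions drawn from this report (1.0.0 failed, 0.11.1 worked), not an official compatibility matrix.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def torchmetrics_is_compatible(installed: Optional[str] = None, max_major: int = 0) -> bool:
    """Return True if torchmetrics' major version is at most `max_major`.

    Hypothetical helper: max_major=0 reflects this report, where
    torchmetrics 1.0.0 failed and 0.11.1 worked.
    """
    if installed is None:
        try:
            installed = version("torchmetrics")  # read installed package metadata
        except PackageNotFoundError:
            return False  # torchmetrics is not installed at all
    major = int(installed.split(".")[0])
    return major <= max_major


if __name__ == "__main__":
    if not torchmetrics_is_compatible():
        print("torchmetrics >= 1.0 or missing; try: pip install torchmetrics==0.11.1")
```

Checking the major version only keeps the sketch dependency-free; a stricter check could compare full version strings with a dedicated version-parsing library.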

It would be helpful to pin the torchmetrics version in the requirements.txt file.

I always appreciate your awesome research!

Mingyu Jeon

Yeonghun1675 commented 1 year ago

Hi, @MINGUUUS!

Thank you for letting us know about this issue and the fix. We'll pin the version as you suggested. Let us know if you run into any other issues.