SortAnon / ControllableTalkNet

A web app that lets you play around with TalkNet models
GNU Affero General Public License v3.0
121 stars 48 forks source link

ImportError: cannot import name 'get_num_classes' from 'torchmetrics.utilities.data' on step 3 #35

Open GhostDog98 opened 1 year ago

GhostDog98 commented 1 year ago

I'm currently running the Docker container in a Linux environment; however, when running step 3, it returns the following error:

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[4], line 4
      1 # Extract phoneme duration
      3 import json
----> 4 from nemo.collections.asr.models import EncDecCTCModel
      5 asr_model = EncDecCTCModel.from_pretrained(model_name="asr_talknet_aligner").cpu().eval()
      7 def forward_extractor(tokens, log_probs, blank):

File ~/anaconda3/envs/talknet/lib/python3.8/site-packages/nemo/collections/asr/__init__.py:15
      1 # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from nemo.collections.asr import data, losses, models, modules
     16 from nemo.package_info import __version__
     18 # Set collection version equal to NeMo version.

File ~/anaconda3/envs/talknet/lib/python3.8/site-packages/nemo/collections/asr/losses/__init__.py:15
      1 # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss
     16 from nemo.collections.asr.losses.audio_losses import SDRLoss
     17 from nemo.collections.asr.losses.ctc import CTCLoss

File ~/anaconda3/envs/talknet/lib/python3.8/site-packages/nemo/collections/asr/losses/angularloss.py:18
      1 # ! /usr/bin/python
      2 # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
      3 #
   (...)
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     16 import torch
---> 18 from nemo.core.classes import Loss, Typing, typecheck
     19 from nemo.core.neural_types import LabelsType, LogitsType, LossType, NeuralType
     21 __all__ = ['AngularSoftmaxLoss']

File ~/anaconda3/envs/talknet/lib/python3.8/site-packages/nemo/core/__init__.py:16
      1 # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     15 import nemo.core.neural_types
---> 16 from nemo.core.classes import *

File ~/anaconda3/envs/talknet/lib/python3.8/site-packages/nemo/core/classes/__init__.py:18
     16 import hydra
     17 import omegaconf
---> 18 import pytorch_lightning
     20 from nemo.core.classes.common import (
     21     FileIO,
     22     Model,
   (...)
     27     typecheck,
     28 )
     29 from nemo.core.classes.dataset import Dataset, IterableDataset

File ~/anaconda3/envs/talknet/lib/python3.8/site-packages/pytorch_lightning/__init__.py:20
     17 _PACKAGE_ROOT = os.path.dirname(__file__)
     18 _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
---> 20 from pytorch_lightning import metrics  # noqa: E402
     21 from pytorch_lightning.callbacks import Callback  # noqa: E402
     22 from pytorch_lightning.core import LightningDataModule, LightningModule  # noqa: E402

File ~/anaconda3/envs/talknet/lib/python3.8/site-packages/pytorch_lightning/metrics/__init__.py:15
      1 # Copyright The PyTorch Lightning team.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from pytorch_lightning.metrics.classification import (  # noqa: F401
     16     Accuracy,
     17     AUC,
     18     AUROC,
     19     AveragePrecision,
     20     ConfusionMatrix,
     21     F1,
     22     FBeta,
     23     HammingDistance,
     24     IoU,
     25     Precision,
     26     PrecisionRecallCurve,
     27     Recall,
     28     ROC,
     29     StatScores,
     30 )
     31 from pytorch_lightning.metrics.metric import Metric, MetricCollection  # noqa: F401
     32 from pytorch_lightning.metrics.regression import (  # noqa: F401
     33     ExplainedVariance,
     34     MeanAbsoluteError,
   (...)
     39     SSIM,
     40 )

File ~/anaconda3/envs/talknet/lib/python3.8/site-packages/pytorch_lightning/metrics/classification/__init__.py:14
      1 # Copyright The PyTorch Lightning team.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 14 from pytorch_lightning.metrics.classification.accuracy import Accuracy  # noqa: F401
     15 from pytorch_lightning.metrics.classification.auc import AUC  # noqa: F401
     16 from pytorch_lightning.metrics.classification.auroc import AUROC  # noqa: F401

File ~/anaconda3/envs/talknet/lib/python3.8/site-packages/pytorch_lightning/metrics/classification/accuracy.py:18
     14 from typing import Any, Callable, Optional
     16 from torchmetrics import Accuracy as _Accuracy
---> 18 from pytorch_lightning.metrics.utils import deprecated_metrics
     21 class Accuracy(_Accuracy):
     23     @deprecated_metrics(target=_Accuracy)
     24     def __init__(
     25         self,
   (...)
     32         dist_sync_fn: Callable = None,
     33     ):

File ~/anaconda3/envs/talknet/lib/python3.8/site-packages/pytorch_lightning/metrics/utils.py:22
     20 from torchmetrics.utilities.data import dim_zero_mean as _dim_zero_mean
     21 from torchmetrics.utilities.data import dim_zero_sum as _dim_zero_sum
---> 22 from torchmetrics.utilities.data import get_num_classes as _get_num_classes
     23 from torchmetrics.utilities.data import select_topk as _select_topk
     24 from torchmetrics.utilities.data import to_categorical as _to_categorical

ImportError: cannot import name 'get_num_classes' from 'torchmetrics.utilities.data' (/home/ghostdog/anaconda3/envs/talknet/lib/python3.8/site-packages/torchmetrics/utilities/data.py)

I've tried reinstalling torchmetrics version 0.6.0 using the command `conda install -c conda-forge torchmetrics=0.6.0`. What can I do to remedy this?

GhostDog98 commented 1 year ago

Hmm... it seems the issue is that `pytorch-lightning==1.3.8` is not installed during setup; a different version is installed instead. Running `pip install pytorch-lightning==1.3.8` fixed that error. I'm still getting a different error, though, this time: `ImportError: /home/ghostdog/anaconda3/envs/talknet/lib/python3.8/site-packages/torchtext/_torchtext.so: undefined symbol: _ZNK3c104Type14isSubtypeOfExtERKSt10shared_ptrIS0_EPSo`