automl / TabPFN

Official implementation of the TabPFN paper (https://arxiv.org/abs/2207.01848) and the tabpfn package.
http://priorlabs.ai
Apache License 2.0

NameError: name 'Module' is not defined #14

Closed: Cryaaa closed this issue 2 years ago

Cryaaa commented 2 years ago

Hello, I wanted to test your library on some of my data, but I'm having trouble importing the classifier with:

from tabpfn.scripts.transformer_prediction_interface import TabPFNClassifier

The full traceback is:

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In [1], line 1
----> 1 from tabpfn.scripts.transformer_prediction_interface import TabPFNClassifier

File ~\OneDrive\Documents\PhD Jesse\TabPFN-main\tabpfn\scripts\transformer_prediction_interface.py:20
     18 from sklearn.utils import column_or_1d
     19 from pathlib import Path
---> 20 from tabpfn.scripts.model_builder import load_model
     21 import os
     22 import pickle

File ~\OneDrive\Documents\PhD Jesse\TabPFN-main\tabpfn\scripts\model_builder.py:2
      1 from functools import partial
----> 2 from tabpfn.train import train, Losses
      3 import tabpfn.priors as priors
      4 import tabpfn.encoders as encoders

File ~\OneDrive\Documents\PhD Jesse\TabPFN-main\tabpfn\train.py:14
     11 from torch import nn
     13 import tabpfn.utils as utils
---> 14 from tabpfn.transformer import TransformerModel
     15 from tabpfn.utils import get_cosine_schedule_with_warmup, get_openai_lr, StoreDictKeyPair, get_weighted_single_eval_pos_sampler, get_uniform_single_eval_pos_sampler
     16 import tabpfn.priors as priors

File ~\OneDrive\Documents\PhD Jesse\TabPFN-main\tabpfn\transformer.py:9
      6 from torch import Tensor
      7 from torch.nn import Module, TransformerEncoder
----> 9 from tabpfn.layer import TransformerEncoderLayer, _get_activation_fn
     10 from tabpfn.utils import SeqBN, bool_mask_to_att_mask
     14 class TransformerModel(nn.Module):

File ~\OneDrive\Documents\PhD Jesse\TabPFN-main\tabpfn\layer.py:10
      5 from torch.nn.modules.transformer import _get_activation_fn
      7 from torch.utils.checkpoint import checkpoint
---> 10 class TransformerEncoderLayer(Module):
     11     r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
     12     This standard encoder layer is based on the paper "Attention Is All You Need".
     13     Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
   (...)
     36         >>> out = encoder_layer(src)
     37     """
     38     __constants__ = ['batch_first']

NameError: name 'Module' is not defined

If it helps, here is the list of installed pip packages:

Package                            Version
---------------------------------- -----------------
alabaster                          0.7.12
anaconda-client                    1.7.2
anaconda-navigator                 2.0.3
anaconda-project                   0.9.1
anyio                              2.2.0
appdirs                            1.4.4
argh                               0.26.2
argon2-cffi                        20.1.0
asn1crypto                         1.4.0
astroid                            2.5
astropy                            4.2.1
async-generator                    1.10
atomicwrites                       1.4.0
attrs                              20.3.0
autopep8                           1.5.6
Babel                              2.9.0
backcall                           0.2.0
backports.functools-lru-cache      1.6.4
backports.shutil-get-terminal-size 1.0.0
backports.tempfile                 1.0
backports.weakref                  1.0.post1
bcrypt                             3.2.0
beautifulsoup4                     4.9.3
bitarray                           1.9.2
bkcharts                           0.2
black                              19.10b0
bleach                             3.3.0
bokeh                              2.3.2
boto                               2.49.0
Bottleneck                         1.3.2
brotlipy                           0.7.0
certifi                            2020.12.5
cffi                               1.14.5
chardet                            4.0.0
click                              7.1.2
cloudpickle                        1.6.0
clyent                             1.2.2
colorama                           0.4.4
comtypes                           1.1.9
conda                              4.13.0
conda-build                        3.21.4
conda-content-trust                0+unknown
conda-package-handling             1.8.1
conda-repo-cli                     1.0.4
conda-token                        0.3.0
conda-verify                       3.4.2
ConfigSpace                        0.6.0
contextlib2                        0.6.0.post1
cryptography                       3.4.7
cycler                             0.10.0
Cython                             0.29.23
cytoolz                            0.11.0
dask                               2021.4.0
decorator                          5.0.6
defusedxml                         0.7.1
diff-match-patch                   20200713
distributed                        2021.4.0
docutils                           0.17
entrypoints                        0.3
et-xmlfile                         1.0.1
fastcache                          1.1.0
filelock                           3.0.12
flake8                             3.9.0
Flask                              1.1.2
fsspec                             0.9.0
future                             0.18.2
gevent                             21.1.2
glob2                              0.7
gpytorch                           1.9.0
greenlet                           1.0.0
h5py                               2.10.0
HeapDict                           1.0.1
html5lib                           1.1
hyperopt                           0.2.7
idna                               2.10
imagecodecs                        2021.3.31
imageio                            2.9.0
imagesize                          1.2.0
importlib-metadata                 3.10.0
iniconfig                          1.1.1
intervaltree                       3.1.0
ipykernel                          5.3.4
ipython                            7.22.0
ipython-genutils                   0.2.0
ipywidgets                         7.6.3
isort                              5.8.0
itsdangerous                       1.1.0
jdcal                              1.4.1
jedi                               0.17.2
Jinja2                             2.11.3
joblib                             1.0.1
json5                              0.9.5
jsonschema                         3.2.0
jupyter                            1.0.0
jupyter-client                     6.1.12
jupyter-console                    6.4.0
jupyter-core                       4.7.1
jupyter-packaging                  0.7.12
jupyter-server                     1.4.1
jupyterlab                         3.0.14
jupyterlab-pygments                0.1.2
jupyterlab-server                  2.4.0
jupyterlab-widgets                 1.0.0
keyring                            22.3.0
kiwisolver                         1.3.1
lazy-object-proxy                  1.6.0
liac-arff                          2.5.0
libarchive-c                       2.9
linear-operator                    0.1.1
llvmlite                           0.36.0
locket                             0.2.1
lxml                               4.6.3
MarkupSafe                         1.1.1
matplotlib                         3.3.4
mccabe                             0.6.1
menuinst                           1.4.16
minio                              7.1.12
mistune                            0.8.4
mkl-fft                            1.3.0
mkl-random                         1.2.1
mkl-service                        2.3.0
mock                               4.0.3
more-itertools                     8.7.0
mpmath                             1.2.1
msgpack                            1.0.2
multipledispatch                   0.6.0
mypy-extensions                    0.4.3
navigator-updater                  0.2.1
nbclassic                          0.2.6
nbclient                           0.5.3
nbconvert                          6.0.7
nbformat                           5.1.3
nest-asyncio                       1.5.1
networkx                           2.5
nltk                               3.6.1
nose                               1.3.7
notebook                           6.3.0
numba                              0.53.1
numexpr                            2.7.3
numpy                              1.22.4
numpydoc                           1.1.0
olefile                            0.46
openml                             0.12.2
openpyxl                           3.0.7
packaging                          20.9
pandas                             1.2.4
pandocfilters                      1.4.3
paramiko                           2.7.2
parso                              0.7.0
partd                              1.2.0
path                               15.1.2
pathlib2                           2.3.5
pathspec                           0.7.0
patsy                              0.5.1
pep8                               1.7.1
pexpect                            4.8.0
pickleshare                        0.7.5
Pillow                             8.2.0
pip                                21.0.1
pkginfo                            1.7.0
pluggy                             0.13.1
ply                                3.11
prometheus-client                  0.10.1
prompt-toolkit                     3.0.17
psutil                             5.8.0
ptyprocess                         0.7.0
py                                 1.10.0
py4j                               0.10.9.7
pyarrow                            10.0.0
pycodestyle                        2.6.0
pycosat                            0.6.3
pycparser                          2.20
pycurl                             7.43.0.6
pydocstyle                         6.0.0
pyerfa                             1.7.3
pyflakes                           2.2.0
Pygments                           2.8.1
pylint                             2.7.4
pyls-black                         0.4.6
pyls-spyder                        0.3.2
PyNaCl                             1.4.0
pyodbc                             4.0.0-unsupported
pyOpenSSL                          20.0.1
pyparsing                          2.4.7
pyreadline                         2.1
pyrsistent                         0.17.3
PySocks                            1.7.1
pytest                             6.2.3
python-dateutil                    2.8.1
python-jsonrpc-server              0.4.0
python-language-server             0.36.2
pytz                               2021.1
PyWavelets                         1.1.1
pywin32                            227
pywin32-ctypes                     0.2.0
pywinpty                           0.5.7
PyYAML                             5.4.1
pyzmq                              20.0.0
QDarkStyle                         2.8.1
QtAwesome                          1.0.2
qtconsole                          5.0.3
QtPy                               1.9.0
regex                              2021.4.4
requests                           2.25.1
rope                               0.18.0
Rtree                              0.9.7
ruamel-yaml-conda                  0.15.100
scikit-image                       0.18.1
scikit-learn                       1.1.3
scipy                              1.6.2
seaborn                            0.12.1
Send2Trash                         1.5.0
setuptools                         58.2.0
simplegeneric                      0.8.1
singledispatch                     0.0.0
sip                                4.19.13
six                                1.15.0
sniffio                            1.2.0
snowballstemmer                    2.1.0
sortedcollections                  2.1.0
sortedcontainers                   2.3.0
soupsieve                          2.2.1
Sphinx                             4.0.1
sphinxcontrib-applehelp            1.0.2
sphinxcontrib-devhelp              1.0.2
sphinxcontrib-htmlhelp             1.0.3
sphinxcontrib-jsmath               1.0.1
sphinxcontrib-qthelp               1.0.3
sphinxcontrib-serializinghtml      1.1.4
sphinxcontrib-websupport           1.2.4
spyder                             4.2.5
spyder-kernels                     1.10.2
SQLAlchemy                         1.4.7
statsmodels                        0.12.2
sympy                              1.8
tables                             3.6.1
tabpfn                             0.1.5
tblib                              1.7.0
terminado                          0.9.4
testpath                           0.4.4
textdistance                       4.2.1
threadpoolctl                      2.1.0
three-merge                        0.1.1
tifffile                           2021.4.8
toml                               0.10.2
toolz                              0.11.1
torch                              1.13.0
tornado                            6.1
tqdm                               4.64.1
traitlets                          5.0.5
twine                              3.4.2
typed-ast                          1.4.2
typing-extensions                  3.7.4.3
ujson                              4.0.2
unicodecsv                         0.14.1
urllib3                            1.26.4
watchdog                           1.0.2
wcwidth                            0.2.5
webencodings                       0.5.1
Werkzeug                           1.0.1
wheel                              0.37.0
widgetsnbextension                 3.5.1
win-inet-pton                      1.1.0
win-unicode-console                0.5
wincertstore                       0.2
wrapt                              1.12.1
xlrd                               2.0.1
XlsxWriter                         1.3.8
xlwings                            0.23.0
xlwt                               1.3.0
xmltodict                          0.12.0
yapf                               0.31.0
zict                               2.0.0
zipp                               3.4.1
zope.event                         4.5.0
zope.interface                     5.3.0

Maybe I'm missing something simple, but I can't figure out why it's throwing this error...
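A quick way to narrow down a NameError like this is to check which names a wildcard import actually binds. Under Python's import rules, `from m import *` binds only the names listed in `m.__all__` when that attribute is defined, and otherwise every name that does not start with an underscore. The helper below applies that rule; `wildcard_exports` is a hypothetical name for illustration, not part of tabpfn or torch, and the torch check at the end only runs if torch is installed.

```python
import importlib
import importlib.util


def wildcard_exports(module_name):
    """Return the set of names `from module_name import *` would bind (hypothetical helper)."""
    mod = importlib.import_module(module_name)
    names = getattr(mod, "__all__", None)
    if names is None:
        # No __all__: the wildcard binds every public (non-underscore) name.
        names = [n for n in vars(mod) if not n.startswith("_")]
    return set(names)


if importlib.util.find_spec("torch") is not None:
    exported = wildcard_exports("torch.nn.modules.transformer")
    # If this prints False, a wildcard import from this module cannot provide Module.
    print("Module provided by import *:", "Module" in exported)
```

Running this in the same environment as the notebook shows whether the installed torch still exposes `Module` through the wildcard import that layer.py relies on.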

fatihozturkh2o commented 2 years ago

Hey, I wanted to try out the library on my datasets as well and just ran into the exact same error.

fatihozturkh2o commented 2 years ago

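In the file tabpfn\layer.py, the line from torch.nn.modules.transformer import * is the tricky part. Here it does not bring in Module (or the other torch names that layer.py needs), which is what causes the NameError, and wildcard imports are not good coding practice anyway. Importing the names explicitly solves the issue: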

import torch
from torch.nn.modules.transformer import _get_activation_fn, Module, Tensor, Optional, MultiheadAttention, Linear, Dropout, LayerNorm
from torch.utils.checkpoint import checkpoint
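Why the explicit import works where `import *` fails: `__all__` only restricts wildcard imports, while `from m import name` looks the attribute up on the module directly. The sketch below reproduces the situation with a hypothetical in-memory module `fake_transformer` (not a real torch module) that defines `Module` but lists only its own layer class in `__all__`. The assumption is that the installed torch's `torch/nn/modules/transformer.py` behaves the same way; the traceback is consistent with that, but this thread does not confirm it.

```python
import sys
import types

# Hypothetical stand-in for a module that uses Module internally
# but only lists its own classes in __all__.
fake = types.ModuleType("fake_transformer")
exec(
    "class Module:\n"
    "    pass\n"
    "\n"
    "class TransformerEncoderLayer(Module):\n"
    "    pass\n"
    "\n"
    "__all__ = ['TransformerEncoderLayer']\n",
    fake.__dict__,
)
sys.modules["fake_transformer"] = fake

# Wildcard import: only names listed in __all__ are bound.
star_ns = {}
exec("from fake_transformer import *", star_ns)
print("Module" in star_ns)                   # False
print("TransformerEncoderLayer" in star_ns)  # True

# Explicit import ignores __all__, so Module is still reachable.
explicit_ns = {}
exec("from fake_transformer import Module", explicit_ns)
print("Module" in explicit_ns)               # True
```

This is why listing Module, Tensor, and the other names explicitly, as in the fix above, keeps working regardless of what the module chooses to export.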

noahho commented 2 years ago

Thank you to both of you! I used the fix that you suggested @fatihozturkh2o

SamuelGabriel commented 2 years ago

In case you still care, this fix is now on pip as well :)

Cryaaa commented 2 years ago

Nice one! I had already cloned the project to test it, but it's good to hear that the fix made it into a release so soon 😊