Lightning-Universe / lightning-transformers

Flexible components pairing 🤗 Transformers with :zap: Pytorch Lightning
https://lightning-transformers.readthedocs.io
Apache License 2.0

How can we specify `num_classes` for the model? #215

Closed. jmwoloso closed this issue 2 years ago.

jmwoloso commented 2 years ago

❓ Questions and Help

Before asking:

  1. Search the issues.
  2. Search the docs.

What is your question?

How do I specify `num_classes` when running the CLI from a local clone of the repo?

Code

The issue happens during the call to `configure_metrics` in task/nlp/text_classification/model.py.

What have you tried?

Currently, I've hard-coded the value in `configure_metrics`:

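    from torchmetrics import Accuracy, Precision, Recall

    def configure_metrics(self, _) -> None:
        # Workaround: the number of classes is hard-coded instead of read from the data
        self.prec = Precision(num_classes=2)
        self.recall = Recall(num_classes=2)
        self.acc = Accuracy()
        self.metrics = {"precision": self.prec, "recall": self.recall, "accuracy": self.acc}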
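One way to avoid the hard-coded `2` is to resolve `num_classes` when `configure_metrics` runs: prefer an explicitly supplied value (for example one coming from the CLI/config), otherwise ask the data module. The sketch below is a minimal illustration under that assumption. `ToyDataModule`, `ToyClassifier`, and `metric_kwargs` are hypothetical names, not lightning-transformers APIs, and the dict of keyword arguments stands in for the torchmetrics objects the real method builds. In a real Hugging Face dataset the class count would more likely come from the label feature's `datasets.ClassLabel.num_classes`.

```python
from typing import Dict, List, Optional


class ToyDataModule:
    """Hypothetical stand-in for a text-classification data module."""

    def __init__(self, label_names: List[str]) -> None:
        # Mirrors the names stored in a Hugging Face `ClassLabel` feature.
        self.label_names = list(label_names)

    @property
    def num_classes(self) -> int:
        return len(self.label_names)


class ToyClassifier:
    """Hypothetical stand-in for the task model (not the library's class)."""

    def __init__(
        self,
        datamodule: Optional[ToyDataModule] = None,
        num_classes: Optional[int] = None,
    ) -> None:
        self.datamodule = datamodule
        # Explicit override, e.g. a value supplied through the CLI/config.
        self._num_classes = num_classes

    @property
    def num_classes(self) -> int:
        # Prefer the explicit value, then fall back to the data module.
        if self._num_classes is not None:
            return self._num_classes
        if self.datamodule is not None:
            return self.datamodule.num_classes
        raise ValueError("num_classes is unknown: pass it explicitly or attach a data module")

    def configure_metrics(self, _=None) -> Dict[str, Dict[str, int]]:
        # The real method would build Precision(num_classes=...) and Recall(num_classes=...);
        # this sketch only records the keyword arguments so it runs without torchmetrics.
        self.metric_kwargs = {
            "precision": {"num_classes": self.num_classes},
            "recall": {"num_classes": self.num_classes},
        }
        return self.metric_kwargs


model = ToyClassifier(datamodule=ToyDataModule(["negative", "positive"]))
print(model.configure_metrics())
# {'precision': {'num_classes': 2}, 'recall': {'num_classes': 2}}
```

Checking the explicit value first means a number passed on the command line would win over whatever the data module reports, which can help when debugging a label mismatch.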

What's your environment?

OS: Linux Mint 19.3

Conda (environment.yml):

name: pml
channels:
  - pytorch
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=1_llvm
  - abseil-cpp=20210324.2=h9c3ff4c_0
  - absl-py=0.13.0=py38h06a4308_0
  - aiohttp=3.8.1=py38h7f8727e_0
  - aiosignal=1.2.0=pyhd3eb1b0_0
  - arrow-cpp=3.0.0=py38h6b21186_4
  - async-timeout=4.0.1=pyhd3eb1b0_0
  - attrs=21.2.0=pyhd3eb1b0_0
  - aws-c-common=0.4.57=he6710b0_1
  - aws-c-event-stream=0.1.6=h2531618_5
  - aws-checksums=0.1.9=he6710b0_0
  - aws-sdk-cpp=1.8.185=hce553d0_0
  - backcall=0.2.0=pyhd3eb1b0_0
  - blas=1.0=mkl
  - blinker=1.4=py38h06a4308_0
  - boost-cpp=1.69.0=h11c811c_1000
  - boto3=1.18.21=pyhd3eb1b0_0
  - botocore=1.21.41=pyhd3eb1b0_1
  - brotli=1.0.9=h7f98852_6
  - brotli-bin=1.0.9=h7f98852_6
  - brotlipy=0.7.0=py38h27cfd23_1003
  - bzip2=1.0.8=h7b6447c_0
  - c-ares=1.17.1=h27cfd23_0
  - ca-certificates=2021.10.8=ha878542_0
  - certifi=2021.10.8=py38h578d9bd_1
  - cffi=1.14.6=py38h400218f_0
  - cryptography=3.4.8=py38hd23ed53_0
  - cudatoolkit=11.1.1=h6406543_9
  - dataclasses=0.8=pyh6d0b6a4_7
  - datasets=1.16.1=pyhd8ed1ab_0
  - debugpy=1.5.1=py38h295c915_0
  - decorator=5.1.0=pyhd3eb1b0_0
  - dill=0.3.4=pyhd8ed1ab_0
  - double-conversion=3.1.6=h9c3ff4c_0
  - ffmpeg=4.2.2=h20bf706_0
  - filelock=3.4.0=pyhd8ed1ab_0
  - freetype=2.11.0=h70c0345_0
  - frozenlist=1.2.0=py38h7f8727e_0
  - fsspec=2021.10.1=pyhd3eb1b0_0
  - future=0.18.2=py38_1
  - gflags=2.2.2=he1b5a44_1004
  - giflib=5.2.1=h7b6447c_0
  - glog=0.5.0=h48cff8f_0
  - gmp=6.2.1=h2531618_2
  - gnutls=3.6.15=he1e5248_0
  - grpc-cpp=1.39.0=hae934f6_5
  - grpcio=1.42.0=py38hce63b2e_0
  - huggingface_hub=0.2.1=pyhd8ed1ab_0
  - icu=58.2=hf484d3e_1000
  - idna=3.3=pyhd3eb1b0_0
  - importlib-metadata=4.8.2=py38h06a4308_0
  - importlib_metadata=4.8.2=hd8ed1ab_0
  - intel-openmp=2021.4.0=h06a4308_3561
  - ipython=7.29.0=py38hb070fc8_0
  - ipython_genutils=0.2.0=pyhd3eb1b0_1
  - jmespath=0.10.0=pyhd3eb1b0_0
  - joblib=1.1.0=pyhd3eb1b0_0
  - jpeg=9d=h7f8727e_0
  - jupyter_client=7.0.6=pyhd3eb1b0_0
  - jupyter_core=4.9.1=py38h06a4308_0
  - krb5=1.19.2=hcc1bbae_3
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.35.1=h7274673_9
  - libboost=1.73.0=h3ff78a5_11
  - libbrotlicommon=1.0.9=h7f98852_6
  - libbrotlidec=1.0.9=h7f98852_6
  - libbrotlienc=1.0.9=h7f98852_6
  - libcurl=7.78.0=h0b77cf5_0
  - libedit=3.1.20210910=h7f8727e_0
  - libev=4.33=h516909a_1
  - libevent=2.1.10=h9b69904_4
  - libffi=3.3=he6710b0_2
  - libgcc-ng=11.2.0=h1d223b6_11
  - libidn2=2.3.2=h7f8727e_0
  - libnghttp2=1.43.0=h812cca2_0
  - libopus=1.3.1=h7b6447c_0
  - libpng=1.6.37=hbc83047_0
  - libprotobuf=3.17.2=h4ff587b_1
  - libsodium=1.0.18=h7b6447c_0
  - libssh2=1.10.0=ha56f1ee_2
  - libstdcxx-ng=11.2.0=he4da1e4_11
  - libtasn1=4.16.0=h27cfd23_0
  - libthrift=0.14.2=hcc01f38_0
  - libtiff=4.2.0=h85742a9_0
  - libunistring=0.9.10=h27cfd23_0
  - libuv=1.40.0=h7b6447c_0
  - libvpx=1.7.0=h439df22_0
  - libwebp=1.2.0=h89dd481_0
  - libwebp-base=1.2.0=h27cfd23_0
  - llvm-openmp=12.0.1=h4bd325d_1
  - lz4-c=1.9.3=h295c915_1
  - markdown=3.3.4=py38h06a4308_0
  - mkl=2021.4.0=h06a4308_640
  - mkl-service=2.4.0=py38h7f8727e_0
  - mkl_fft=1.3.1=py38hd3c417c_0
  - mkl_random=1.2.2=py38h51133e4_0
  - multidict=5.1.0=py38h27cfd23_2
  - multiprocess=0.70.12.2=py38h497a2fe_1
  - ncurses=6.3=h7f8727e_2
  - nest-asyncio=1.5.1=pyhd3eb1b0_0
  - nettle=3.7.3=hbbd107a_1
  - numpy-base=1.21.2=py38h79a1101_0
  - oauthlib=3.1.1=pyhd3eb1b0_0
  - olefile=0.46=pyhd3eb1b0_0
  - openh264=2.1.0=hd408876_0
  - openssl=1.1.1l=h7f98852_0
  - orc=1.6.9=ha97a36c_3
  - packaging=21.3=pyhd3eb1b0_0
  - parso=0.8.2=pyhd3eb1b0_0
  - pexpect=4.8.0=pyhd3eb1b0_3
  - pickleshare=0.7.5=pyhd3eb1b0_1003
  - pip=21.2.4=py38h06a4308_0
  - ptyprocess=0.7.0=pyhd3eb1b0_2
  - pyasn1=0.4.8=pyhd3eb1b0_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pydeprecate=0.3.1=pyhd8ed1ab_0
  - pygments=2.10.0=pyhd3eb1b0_0
  - pyparsing=3.0.4=pyhd3eb1b0_0
  - pysocks=1.7.1=py38h06a4308_0
  - python=3.8.12=h12debd9_0
  - python-dateutil=2.8.2=pyhd3eb1b0_0
  - python-xxhash=2.0.2=py38h497a2fe_1
  - python_abi=3.8=2_cp38
  - pytorch=1.10.0=py3.8_cuda11.1_cudnn8.0.5_0
  - pytorch-lightning=1.5.5=pyhd8ed1ab_0
  - pytorch-mutex=1.0=cuda
  - pytz=2021.3=pyhd8ed1ab_0
  - pyyaml=6.0=py38h7f8727e_1
  - pyzmq=22.3.0=py38h295c915_2
  - re2=2021.11.01=h9c3ff4c_0
  - readline=8.1=h27cfd23_0
  - regex=2021.8.3=py38h7f8727e_0
  - requests=2.26.0=pyhd3eb1b0_0
  - requests-oauthlib=1.3.0=py_0
  - rsa=4.7.2=pyhd3eb1b0_1
  - s3transfer=0.5.0=pyhd3eb1b0_0
  - sacremoses=0.0.43=pyhd3eb1b0_0
  - setuptools=58.0.4=py38h06a4308_0
  - six=1.16.0=pyhd3eb1b0_0
  - snappy=1.1.8=he1b5a44_3
  - sqlite=3.36.0=hc218d9a_0
  - tk=8.6.11=h1ccaba5_0
  - tokenizers=0.10.3=py38hb63a372_1
  - torchaudio=0.10.0=py38_cu111
  - torchmetrics=0.6.1=pyhd8ed1ab_0
  - torchvision=0.11.1=py38_cu111
  - tornado=6.1=py38h27cfd23_0
  - tqdm=4.62.3=pyhd3eb1b0_1
  - traitlets=5.1.1=pyhd3eb1b0_0
  - transformers=4.11.3=pyhd8ed1ab_0
  - typing-extensions=3.10.0.2=hd3eb1b0_0
  - typing_extensions=3.10.0.2=pyh06a4308_0
  - uriparser=0.9.5=h9c3ff4c_0
  - utf8proc=2.6.1=h27cfd23_0
  - wcwidth=0.2.5=pyhd3eb1b0_0
  - werkzeug=2.0.2=pyhd3eb1b0_0
  - wheel=0.37.0=pyhd3eb1b0_1
  - x264=1!157.20191217=h7b6447c_0
  - xxhash=0.8.0=h7f98852_3
  - xz=5.2.5=h7b6447c_0
  - yaml=0.2.5=h7b6447c_0
  - yarl=1.6.3=py38h27cfd23_0
  - zeromq=4.3.4=h2531618_0
  - zipp=3.6.0=pyhd3eb1b0_0
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.4.9=haebb681_0
  - pip:
    - adal==1.2.7
    - antlr4-python3-runtime==4.8
    - applicationinsights==0.11.10
    - astunparse==1.6.3
    - azure-common==1.1.27
    - azure-core==1.20.1
    - azure-graphrbac==0.61.1
    - azure-identity==1.7.0
    - azure-mgmt-authorization==0.61.0
    - azure-mgmt-containerregistry==8.2.0
    - azure-mgmt-core==1.3.0
    - azure-mgmt-keyvault==9.3.0
    - azure-mgmt-resource==13.0.0
    - azure-mgmt-storage==11.2.0
    - azureml-core==1.36.0.post2
    - azureml-dataprep==2.24.4
    - azureml-dataprep-native==38.0.0
    - azureml-dataprep-rslex==2.0.3
    - azureml-dataset-runtime==1.36.0
    - azureml-defaults==1.36.0
    - azureml-inference-server-http==0.4.2
    - azureml-mlflow==1.36.0
    - azureml-telemetry==1.36.0
    - backports-tempfile==1.0
    - backports-weakref==1.0.post1
    - cachetools==4.2.4
    - charset-normalizer==2.0.7
    - click==8.0.3
    - cloudpickle==2.0.0
    - configparser==3.7.4
    - contextlib2==21.6.0
    - cycler==0.11.0
    - databricks-cli==0.16.2
    - deepspeed==0.5.8
    - distro==1.6.0
    - docker==5.0.3
    - dotnetcore2==2.1.21
    - entrypoints==0.3
    - flask==1.0.3
    - flatbuffers==2.0
    - fonttools==4.28.1
    - fusepy==3.0.1
    - gast==0.4.0
    - gitdb==4.0.9
    - gitpython==3.1.24
    - google-auth==2.3.3
    - google-auth-oauthlib==0.4.6
    - google-pasta==0.2.0
    - gunicorn==20.1.0
    - h5py==3.6.0
    - hjson==3.0.2
    - horovod==0.23.0
    - hydra-core==1.1.0
    - importlib-resources==5.4.0
    - inference-schema==1.3.0
    - ipykernel==6.5.1
    - isodate==0.6.0
    - itsdangerous==2.0.1
    - jedi==0.18.1
    - jeepney==0.7.1
    - jinja2==3.0.3
    - json-logging-py==0.2
    - jsonpickle==2.0.0
    - keras==2.7.0
    - keras-preprocessing==1.1.2
    - kiwisolver==1.3.2
    - libclang==12.0.0
    - lightgbm==3.3.1
    - markupsafe==2.0.1
    - matplotlib==3.5.0
    - matplotlib-inline==0.1.3
    - mlflow-skinny==1.21.0
    - msal==1.16.0
    - msal-extensions==0.3.0
    - msrest==0.6.21
    - msrestazure==0.6.4
    - ndg-httpsclient==0.5.1
    - ninja==1.10.2.3
    - numpy==1.21.4
    - omegaconf==2.1.1
    - onnxruntime-gpu==1.9.0
    - opt-einsum==3.3.0
    - pandas==1.3.4
    - pathspec==0.9.0
    - pillow==8.4.0
    - plotly==5.4.0
    - portalocker==1.7.1
    - prompt-toolkit==3.0.22
    - protobuf==3.19.1
    - psutil==5.8.0
    - pyarrow==3.0.0
    - pyasn1-modules==0.2.8
    - pyjwt==2.3.0
    - pyopenssl==20.0.1
    - scikit-learn==1.0.1
    - scipy==1.7.2
    - secretstorage==3.3.1
    - setuptools-scm==6.3.2
    - smmap==5.0.0
    - tabulate==0.8.9
    - tenacity==8.0.1
    - tensorboard==2.7.0
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.0
    - tensorflow-estimator==2.7.0
    - tensorflow-gpu==2.7.0
    - tensorflow-io-gcs-filesystem==0.22.0
    - termcolor==1.1.0
    - threadpoolctl==3.0.0
    - tomli==1.2.2
    - torch-tb-profiler==0.3.1
    - triton==1.1.1
    - urllib3==1.26.7
    - websocket-client==1.2.1
    - wrapt==1.13.3
prefix: /anaconda/envs/pml

requirements.txt:

adal==1.2.7
antlr4-python3-runtime==4.8
applicationinsights==0.11.10
astunparse==1.6.3
azure-common==1.1.27
azure-core==1.20.1
azure-graphrbac==0.61.1
azure-identity==1.7.0
azure-mgmt-authorization==0.61.0
azure-mgmt-containerregistry==8.2.0
azure-mgmt-core==1.3.0
azure-mgmt-keyvault==9.3.0
azure-mgmt-resource==13.0.0
azure-mgmt-storage==11.2.0
azureml-core==1.36.0.post2
azureml-dataprep==2.24.4
azureml-dataprep-native==38.0.0
azureml-dataprep-rslex==2.0.3
azureml-dataset-runtime==1.36.0
azureml-defaults==1.36.0
azureml-inference-server-http==0.4.2
azureml-mlflow==1.36.0
azureml-telemetry==1.36.0
backports-tempfile==1.0
backports-weakref==1.0.post1
cachetools==4.2.4
charset-normalizer==2.0.7
click==8.0.3
cloudpickle==2.0.0
configparser==3.7.4
contextlib2==21.6.0
cycler==0.11.0
databricks-cli==0.16.2
deepspeed==0.5.8
distro==1.6.0
docker==5.0.3
dotnetcore2==2.1.21
entrypoints==0.3
flask==1.0.3
flatbuffers==2.0
fonttools==4.28.1
fusepy==3.0.1
gast==0.4.0
gitdb==4.0.9
gitpython==3.1.24
google-auth==2.3.3
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
gunicorn==20.1.0
h5py==3.6.0
hjson==3.0.2
horovod==0.23.0
hydra-core==1.1.0
importlib-resources==5.4.0
inference-schema==1.3.0
ipykernel==6.5.1
isodate==0.6.0
itsdangerous==2.0.1
jedi==0.18.1
jeepney==0.7.1
jinja2==3.0.3
json-logging-py==0.2
jsonpickle==2.0.0
keras==2.7.0
keras-preprocessing==1.1.2
kiwisolver==1.3.2
libclang==12.0.0
lightgbm==3.3.1
markupsafe==2.0.1
matplotlib==3.5.0
matplotlib-inline==0.1.3
mlflow-skinny==1.21.0
msal==1.16.0
msal-extensions==0.3.0
msrest==0.6.21
msrestazure==0.6.4
ndg-httpsclient==0.5.1
ninja==1.10.2.3
numpy==1.21.4
omegaconf==2.1.1
onnxruntime-gpu==1.9.0
opt-einsum==3.3.0
pandas==1.3.4
pathspec==0.9.0
pillow==8.4.0
plotly==5.4.0
portalocker==1.7.1
prompt-toolkit==3.0.22
protobuf==3.19.1
psutil==5.8.0
pyarrow==3.0.0
pyasn1-modules==0.2.8
pyjwt==2.3.0
pyopenssl==20.0.1
scikit-learn==1.0.1
scipy==1.7.2
secretstorage==3.3.1
setuptools-scm==6.3.2
smmap==5.0.0
tabulate==0.8.9
tenacity==8.0.1
tensorboard==2.7.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow-estimator==2.7.0
tensorflow-gpu==2.7.0
tensorflow-io-gcs-filesystem==0.22.0
termcolor==1.1.0
threadpoolctl==3.0.0
tomli==1.2.2
torch-tb-profiler==0.3.1
triton==1.1.1
urllib3==1.26.7
websocket-client==1.2.1
wrapt==1.13.3

Related to https://github.com/PyTorchLightning/lightning-transformers/issues/154

jmwoloso commented 2 years ago

This is related to https://github.com/PyTorchLightning/lightning-transformers/issues/216 and is the result of not having all the necessary properties and methods implemented.
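Because the failure comes from missing properties and methods, a cheap guard is to check the data module up front and fail with a readable message before `configure_metrics` runs. This is a hypothetical helper, not part of lightning-transformers, and the list of required names (only `num_classes` here) is an assumption, since the comment above does not enumerate them.

```python
from typing import Iterable, List


def missing_attributes(obj: object, required: Iterable[str]) -> List[str]:
    """Return the names in `required` that `obj` does not provide."""
    # hasattr() evaluates properties, so a property that itself raises
    # AttributeError is also reported as missing.
    return [name for name in required if not hasattr(obj, name)]


def check_datamodule(datamodule: object, required: Iterable[str] = ("num_classes",)) -> None:
    """Raise early, with a readable message, if the data module is incomplete."""
    missing = missing_attributes(datamodule, required)
    if missing:
        raise TypeError(
            f"{type(datamodule).__name__} is missing required attribute(s): "
            + ", ".join(missing)
        )


class CompleteDataModule:  # hypothetical example
    num_classes = 2


class IncompleteDataModule:  # hypothetical example
    pass


check_datamodule(CompleteDataModule())  # passes silently
try:
    check_datamodule(IncompleteDataModule())
except TypeError as err:
    print(err)  # IncompleteDataModule is missing required attribute(s): num_classes
```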