Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
Apache License 2.0

TypeError when trying to fit a TemporalFusionTransformer model #17458

Closed: jsejdija closed this issue 1 year ago

jsejdija commented 1 year ago

Bug description

Fitting a TemporalFusionTransformer model raises a TypeError.

What version are you seeing the problem on?

lightning 2.0.1 / pytorch-lightning 2.0.1.post0

How to reproduce the bug

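trainer = pl.Trainer(
    devices=1, accelerator="gpu",  # any further arguments were not included in the report
)

tft = TemporalFusionTransformer.from_dataset(
    # arguments were not included in the report
)
# passing tft to trainer.fit(...) raises the TypeError below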

Error messages and logs

TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `TemporalFusionTransformer`
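This error is raised when the object passed to `Trainer.fit` fails an `isinstance` check against the `LightningModule` class of the package the `Trainer` was imported from. The plain-Python sketch below uses hypothetical stand-in classes (not the real Lightning API) to show how two classes that are both named `LightningModule`, but come from different packages, make that check fail with the same message. It illustrates the mechanism under that assumption; it is not Lightning's actual code.

```python
# Hypothetical stand-ins: two distinct classes that are both named
# "LightningModule", as with lightning.pytorch and pytorch_lightning.
def make_base(module_name):
    # Each call to type() creates a brand-new class object.
    return type("LightningModule", (), {"__module__": module_name})


BaseFromLightning = make_base("lightning.pytorch.core.module")
BaseFromPytorchLightning = make_base("pytorch_lightning.core.module")


class TemporalFusionTransformer(BaseFromLightning):
    """Stand-in model; only its base class matters for this sketch."""


def check_model(model, expected_base):
    """Mimic a Trainer-side isinstance check against its own base class."""
    if isinstance(model, expected_base):
        return model
    raise TypeError(
        "`model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, "
        f"got `{type(model).__qualname__}`"
    )


model = TemporalFusionTransformer()
check_model(model, BaseFromLightning)  # same class object: the check passes
try:
    check_model(model, BaseFromPytorchLightning)  # same name, different class
except TypeError as err:
    print(err)
```

The class names match, but `isinstance` compares class objects, not names, so a model built on one package's `LightningModule` is rejected by a `Trainer` from the other.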


Current environment

* CUDA:
  - GPU:
    - NVIDIA A100-PCIE-40GB
  - available: True
  - version: 11.7
* Lightning:
  - lightning: 2.0.1
  - lightning-cloud: 0.5.32
  - lightning-utilities: 0.8.0
  - pytorch-forecasting: 1.0.0
  - pytorch-lightning: 2.0.1.post0
  - pytorch-optimizer: 2.5.1
  - torch: 2.0.0
  - torchmetrics: 0.11.4
* Packages:
  - absl-py: 1.4.0
  - aiohttp: 3.8.4
  - aiosignal: 1.3.1
  - alembic: 1.10.3
  - anyio: 3.6.2
  - argon2-cffi: 21.3.0
  - argon2-cffi-bindings: 21.2.0
  - arrow: 1.2.3
  - asttokens: 2.2.1
  - astunparse: 1.6.3
  - async-timeout: 4.0.2
  - attrs: 22.2.0
  - backcall: 0.2.0
  - backports.functools-lru-cache: 1.6.4
  - beautifulsoup4: 4.12.2
  - bleach: 6.0.0
  - blessed: 1.20.0
  - cachetools: 5.3.0
  - certifi: 2022.12.7
  - cffi: 1.15.1
  - charset-normalizer: 3.1.0
  - click: 8.1.3
  - cmaes: 0.9.1
  - cmake: 3.26.3
  - colorlog: 6.7.0
  - comm: 0.1.3
  - contourpy: 1.0.7
  - convertdate: 2.4.0
  - croniter: 1.3.10
  - cubinlinker: 0.2.2
  - cuda-python: 11.8.1
  - cudf: 23.4.0
  - cupy: 11.6.0
  - cycler: 0.11.0
  - dateutils: 0.6.12
  - debugpy: 1.6.7
  - decorator: 5.1.1
  - deepdiff: 6.3.0
  - defusedxml: 0.7.1
  - dnspython: 2.3.0
  - email-validator: 1.3.1
  - entrypoints: 0.4
  - executing: 1.2.0
  - fastapi: 0.88.0
  - fastavro: 1.7.3
  - fastjsonschema: 2.16.3
  - fastrlock: 0.8
  - filelock: 3.11.0
  - flatbuffers: 23.3.3
  - flit-core: 3.8.0
  - fonttools: 4.39.3
  - frozenlist: 1.3.3
  - fsspec: 2023.4.0
  - gast: 0.4.0
  - google-auth: 2.17.1
  - google-auth-oauthlib: 1.0.0
  - google-pasta: 0.2.0
  - greenlet: 2.0.2
  - grpcio: 1.53.0
  - h11: 0.14.0
  - h5py: 3.8.0
  - hijri-converter: 2.2.4
  - holidays: 0.23
  - httpcore: 0.17.0
  - httptools: 0.5.0
  - httpx: 0.24.0
  - hupper: 1.12
  - idna: 3.4
  - importlib-metadata: 6.6.0
  - importlib-resources: 5.12.0
  - inquirer: 3.1.3
  - ipykernel: 6.22.0
  - ipython: 8.12.0
  - ipython-genutils: 0.2.0
  - ipywidgets: 8.0.6
  - itsdangerous: 2.1.2
  - jedi: 0.18.2
  - jinja2: 3.1.2
  - joblib: 1.2.0
  - jsonschema: 4.17.3
  - jupyter: 1.0.0
  - jupyter-client: 8.2.0
  - jupyter-console: 6.6.3
  - jupyter-core: 5.3.0
  - jupyter-events: 0.6.3
  - jupyter-server: 2.5.0
  - jupyter-server-terminals: 0.4.4
  - jupyterlab-pygments: 0.2.2
  - jupyterlab-widgets: 3.0.7
  - kiwisolver: 1.4.4
  - korean-lunar-calendar: 0.3.1
  - libclang: 16.0.0
  - lightning: 2.0.1
  - lightning-cloud: 0.5.32
  - lightning-utilities: 0.8.0
  - lit: 16.0.1
  - llvmlite: 0.39.1
  - mako: 1.2.4
  - markdown: 3.4.3
  - markdown-it-py: 2.2.0
  - markupsafe: 2.1.2
  - matplotlib: 3.7.1
  - matplotlib-inline: 0.1.6
  - mdurl: 0.1.2
  - mistune: 2.0.5
  - mpmath: 1.3.0
  - multidict: 6.0.4
  - nbclassic: 0.5.5
  - nbclient: 0.7.3
  - nbconvert: 7.3.1
  - nbformat: 5.8.0
  - nest-asyncio: 1.5.6
  - networkx: 3.1
  - notebook: 6.5.4
  - notebook-shim: 0.2.2
  - numba: 0.56.4
  - numpy: 1.23.5
  - nvidia-cublas-cu11:
  - nvidia-cuda-cupti-cu11: 11.7.101
  - nvidia-cuda-nvrtc-cu11: 11.7.99
  - nvidia-cuda-runtime-cu11: 11.7.99
  - nvidia-cudnn-cu11:
  - nvidia-cufft-cu11:
  - nvidia-curand-cu11:
  - nvidia-cusolver-cu11:
  - nvidia-cusparse-cu11:
  - nvidia-nccl-cu11: 2.14.3
  - nvidia-nvtx-cu11: 11.7.91
  - nvtx: 0.2.5
  - oauthlib: 3.2.2
  - opt-einsum: 3.3.0
  - optuna: 3.1.1
  - ordered-set: 4.1.0
  - orjson: 3.8.10
  - packaging: 23.1
  - pandas: 1.5.3
  - pandocfilters: 1.5.0
  - parso: 0.8.3
  - pastedeploy: 3.0.1
  - patsy: 0.5.3
  - pexpect: 4.8.0
  - pickleshare: 0.7.5
  - pillow: 9.5.0
  - pip: 23.1.1
  - pkgutil-resolve-name: 1.3.10
  - plaster: 1.0
  - plaster-pastedeploy: 0.7
  - platformdirs: 3.2.0
  - ply: 3.11
  - prometheus-client: 0.16.0
  - prompt-toolkit: 3.0.38
  - protobuf: 4.21.12
  - psutil: 5.9.5
  - ptxcompiler: 0.7.0
  - ptyprocess: 0.7.0
  - pure-eval: 0.2.2
  - pyarrow: 10.0.1
  - pyasn1: 0.4.8
  - pyasn1-modules: 0.2.8
  - pycparser: 2.21
  - pydantic: 1.10.7
  - pygments: 2.15.1
  - pyjwt: 2.6.0
  - pymeeus: 0.5.12
  - pyparsing: 3.0.9
  - pyqt5: 5.15.7
  - pyqt5-sip: 12.11.0
  - pyramid: 2.0.1
  - pyrsistent: 0.19.3
  - python-dateutil: 2.8.2
  - python-dotenv: 1.0.0
  - python-editor: 1.0.4
  - python-json-logger: 2.0.7
  - python-multipart: 0.0.6
  - pytorch-forecasting: 1.0.0
  - pytorch-lightning: 2.0.1.post0
  - pytorch-optimizer: 2.5.1
  - pytz: 2023.3
  - pyyaml: 6.0
  - pyzmq: 25.0.2
  - qtconsole: 5.4.2
  - qtpy: 2.3.1
  - rapids: 0.0.1
  - readchar: 4.0.5
  - requests: 2.28.2
  - requests-oauthlib: 1.3.1
  - rfc3339-validator: 0.1.4
  - rfc3986-validator: 0.1.1
  - rich: 13.3.3
  - rmm: 23.4.0
  - rsa: 4.9
  - scikit-learn: 1.2.2
  - scipy: 1.10.1
  - send2trash: 1.8.0
  - setuptools: 67.7.1
  - sip: 6.7.9
  - six: 1.16.0
  - sniffio: 1.3.0
  - soupsieve: 2.3.2.post1
  - sqlalchemy: 2.0.9
  - stack-data: 0.6.2
  - starlette: 0.22.0
  - starsessions: 1.3.0
  - statsmodels: 0.13.5
  - sympy: 1.11.1
  - tensorboard: 2.12.2
  - tensorboard-data-server: 0.7.0
  - tensorboard-plugin-wit: 1.8.1
  - tensorflow-io-gcs-filesystem: 0.32.0
  - termcolor: 2.2.0
  - terminado: 0.17.1
  - threadpoolctl: 3.1.0
  - tinycss2: 1.2.1
  - toml: 0.10.2
  - tomli: 2.0.1
  - torch: 2.0.0
  - torchmetrics: 0.11.4
  - tornado: 6.3
  - tqdm: 4.65.0
  - traitlets: 5.9.0
  - translationstring: 1.4
  - triton: 2.0.0
  - typing-extensions: 4.5.0
  - ujson: 5.7.0
  - urllib3: 1.26.15
  - uvicorn: 0.21.1
  - uvloop: 0.17.0
  - venusian: 3.0.0
  - watchfiles: 0.19.0
  - wcwidth: 0.2.6
  - webencodings: 0.5.1
  - webob: 1.8.7
  - websocket-client: 1.5.1
  - websockets: 11.0.1
  - werkzeug: 2.2.3
  - wheel: 0.40.0
  - widgetsnbextension: 4.0.7
  - yarl: 1.8.2
  - zipp: 3.15.0
  - zope.deprecation: 4.4.0
  - zope.interface: 6.0
* System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.10.10
  - version: #76-Ubuntu SMP Fri Mar 17 17:19:29 UTC 2023
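The environment above lists both `lightning` 2.0.1 and `pytorch-lightning` 2.0.1.post0. These are separate installs, so `lightning.pytorch.LightningModule` and `pytorch_lightning.LightningModule` are different classes. A minimal sketch for checking which of the two import paths are present, using only `importlib.util.find_spec` from the standard library; the helper name `find_packages` is hypothetical.

```python
import importlib.util


def find_packages(names=("lightning", "pytorch_lightning")):
    """Report which top-level packages can be found, without importing them."""
    # find_spec on a top-level name only locates the package; it does not import it.
    return {name: importlib.util.find_spec(name) is not None for name in names}


found = find_packages()
if all(found.values()):
    print(
        "Both lightning and pytorch_lightning are installed; their "
        "LightningModule classes are distinct, so keep imports consistent."
    )
else:
    print(found)
```

If both report `True`, the model and the `Trainer` need to come from the same one of the two packages.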

More info

No response

arrow9577 commented 1 year ago

How can this be fixed? I have run into the same problem.

hgersten5 commented 1 year ago

Was there a solution for this? I am running into the same error.

ruuttt commented 10 months ago

You can find a solution here
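A hedged sketch of a check for this mismatch, assuming (as the environment above suggests) that the model's `LightningModule` base and the `Trainer` come from different packages. The helper names and the `TRAINER_MODULES` mapping are hypothetical; the sketch uses only standard Python introspection and assumes the base class is literally named `LightningModule`.

```python
import importlib

# Hypothetical mapping: top-level package that defines the model's
# LightningModule base -> module that exposes the matching Trainer.
TRAINER_MODULES = {
    "lightning": "lightning.pytorch",
    "pytorch_lightning": "pytorch_lightning",
}


def lightning_base_modules(model):
    """__module__ of every class named 'LightningModule' in the model's MRO."""
    return [
        cls.__module__
        for cls in type(model).__mro__
        if cls.__name__ == "LightningModule"
    ]


def trainer_module_for(model):
    """Name of the module whose Trainer matches the model's base class."""
    modules = lightning_base_modules(model)
    if not modules:
        raise TypeError(f"{type(model).__qualname__} does not subclass a LightningModule")
    top_level = modules[0].split(".")[0]
    return TRAINER_MODULES[top_level]


def load_matching_trainer(model):
    """Import the Trainer class from the package the model was built on."""
    return importlib.import_module(trainer_module_for(model)).Trainer
```

Called on the model from the report, `load_matching_trainer(tft)` would return `lightning.pytorch.Trainer` if `tft` subclasses `lightning.pytorch.LightningModule`, and `pytorch_lightning.Trainer` if it subclasses `pytorch_lightning.LightningModule`. In practice the simpler fix is to use one import path consistently throughout the script, e.g. `import lightning.pytorch as pl` everywhere when the model is built on `lightning.pytorch`.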