Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0
28.35k stars, 3.38k forks

TypeError when trying to fit a TemporalFusionTransformer model #17458

Closed: jsejdija closed this issue 1 year ago

jsejdija commented 1 year ago

Bug description

Calling trainer.fit() on a TemporalFusionTransformer raises a TypeError.

What version are you seeing the problem on?

2.0+

How to reproduce the bug

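# Imports were omitted from the report; the ones below are assumed.
# The reported error is consistent with `pl` being the standalone
# `pytorch_lightning` package rather than `lightning.pytorch`.
import pytorch_lightning as pl
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import SMAPE

# `training`, `train_dataloader` and `val_dataloader` are assumed to be
# built beforehand (not shown in the report).

trainer = pl.Trainer(
    max_epochs=10,
    devices=1,
    accelerator="gpu",
    enable_model_summary=True,
    gradient_clip_val=0.25,
    limit_train_batches=10,
)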

tft = TemporalFusionTransformer.from_dataset(
    training,
    lstm_layers=1,
    hidden_size=16,
    attention_head_size=2,
    dropout=0.2,
    hidden_continuous_size=8,
    output_size=1,
    loss=SMAPE(),
    log_interval=10,
    reduce_on_plateau_patience=4
)

trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

Error messages and logs

TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `TemporalFusionTransformer`

Environment

Current environment

* CUDA:
    - GPU:
        - NVIDIA A100-PCIE-40GB
    - available: True
    - version: 11.7
* Lightning:
    - lightning: 2.0.1
    - lightning-cloud: 0.5.32
    - lightning-utilities: 0.8.0
    - pytorch-forecasting: 1.0.0
    - pytorch-lightning: 2.0.1.post0
    - pytorch-optimizer: 2.5.1
    - torch: 2.0.0
    - torchmetrics: 0.11.4
* Packages:
    absl-py: 1.4.0, aiohttp: 3.8.4, aiosignal: 1.3.1, alembic: 1.10.3, anyio: 3.6.2, argon2-cffi: 21.3.0, argon2-cffi-bindings: 21.2.0, arrow: 1.2.3, asttokens: 2.2.1, astunparse: 1.6.3, async-timeout: 4.0.2, attrs: 22.2.0, backcall: 0.2.0, backports.functools-lru-cache: 1.6.4, beautifulsoup4: 4.12.2, bleach: 6.0.0, blessed: 1.20.0, cachetools: 5.3.0, certifi: 2022.12.7, cffi: 1.15.1, charset-normalizer: 3.1.0, click: 8.1.3, cmaes: 0.9.1, cmake: 3.26.3, colorlog: 6.7.0, comm: 0.1.3, contourpy: 1.0.7, convertdate: 2.4.0, croniter: 1.3.10, cubinlinker: 0.2.2, cuda-python: 11.8.1, cudf: 23.4.0, cupy: 11.6.0, cycler: 0.11.0, dateutils: 0.6.12, debugpy: 1.6.7, decorator: 5.1.1, deepdiff: 6.3.0, defusedxml: 0.7.1, dnspython: 2.3.0, email-validator: 1.3.1, entrypoints: 0.4, executing: 1.2.0, fastapi: 0.88.0, fastavro: 1.7.3, fastjsonschema: 2.16.3, fastrlock: 0.8, filelock: 3.11.0, flatbuffers: 23.3.3, flit-core: 3.8.0, fonttools: 4.39.3, frozenlist: 1.3.3, fsspec: 2023.4.0, gast: 0.4.0, google-auth: 2.17.1, google-auth-oauthlib: 1.0.0, google-pasta: 0.2.0, greenlet: 2.0.2, grpcio: 1.53.0, h11: 0.14.0, h5py: 3.8.0, hijri-converter: 2.2.4, holidays: 0.23, httpcore: 0.17.0, httptools: 0.5.0, httpx: 0.24.0, hupper: 1.12, idna: 3.4, importlib-metadata: 6.6.0, importlib-resources: 5.12.0, inquirer: 3.1.3, ipykernel: 6.22.0, ipython: 8.12.0, ipython-genutils: 0.2.0, ipywidgets: 8.0.6, itsdangerous: 2.1.2, jedi: 0.18.2, jinja2: 3.1.2, joblib: 1.2.0, jsonschema: 4.17.3, jupyter: 1.0.0, jupyter-client: 8.2.0, jupyter-console: 6.6.3, jupyter-core: 5.3.0, jupyter-events: 0.6.3, jupyter-server: 2.5.0, jupyter-server-terminals: 0.4.4, jupyterlab-pygments: 0.2.2, jupyterlab-widgets: 3.0.7, kiwisolver: 1.4.4, korean-lunar-calendar: 0.3.1, libclang: 16.0.0, lightning: 2.0.1, lightning-cloud: 0.5.32, lightning-utilities: 0.8.0, lit: 16.0.1, llvmlite: 0.39.1, mako: 1.2.4, markdown: 3.4.3, markdown-it-py: 2.2.0, markupsafe: 2.1.2, matplotlib: 3.7.1, matplotlib-inline: 0.1.6, mdurl: 0.1.2, mistune: 2.0.5, mpmath: 1.3.0, multidict: 6.0.4, nbclassic: 0.5.5, nbclient: 0.7.3, nbconvert: 7.3.1, nbformat: 5.8.0, nest-asyncio: 1.5.6, networkx: 3.1, notebook: 6.5.4, notebook-shim: 0.2.2, numba: 0.56.4, numpy: 1.23.5, nvidia-cublas-cu11: 11.10.3.66, nvidia-cuda-cupti-cu11: 11.7.101, nvidia-cuda-nvrtc-cu11: 11.7.99, nvidia-cuda-runtime-cu11: 11.7.99, nvidia-cudnn-cu11: 8.5.0.96, nvidia-cufft-cu11: 10.9.0.58, nvidia-curand-cu11: 10.2.10.91, nvidia-cusolver-cu11: 11.4.0.1, nvidia-cusparse-cu11: 11.7.4.91, nvidia-nccl-cu11: 2.14.3, nvidia-nvtx-cu11: 11.7.91, nvtx: 0.2.5, oauthlib: 3.2.2, opt-einsum: 3.3.0, optuna: 3.1.1, ordered-set: 4.1.0, orjson: 3.8.10, packaging: 23.1, pandas: 1.5.3, pandocfilters: 1.5.0, parso: 0.8.3, pastedeploy: 3.0.1, patsy: 0.5.3, pexpect: 4.8.0, pickleshare: 0.7.5, pillow: 9.5.0, pip: 23.1.1, pkgutil-resolve-name: 1.3.10, plaster: 1.0, plaster-pastedeploy: 0.7, platformdirs: 3.2.0, ply: 3.11, prometheus-client: 0.16.0, prompt-toolkit: 3.0.38, protobuf: 4.21.12, psutil: 5.9.5, ptxcompiler: 0.7.0, ptyprocess: 0.7.0, pure-eval: 0.2.2, pyarrow: 10.0.1, pyasn1: 0.4.8, pyasn1-modules: 0.2.8, pycparser: 2.21, pydantic: 1.10.7, pygments: 2.15.1, pyjwt: 2.6.0, pymeeus: 0.5.12, pyparsing: 3.0.9, pyqt5: 5.15.7, pyqt5-sip: 12.11.0, pyramid: 2.0.1, pyrsistent: 0.19.3, python-dateutil: 2.8.2, python-dotenv: 1.0.0, python-editor: 1.0.4, python-json-logger: 2.0.7, python-multipart: 0.0.6, pytorch-forecasting: 1.0.0, pytorch-lightning: 2.0.1.post0, pytorch-optimizer: 2.5.1, pytz: 2023.3, pyyaml: 6.0, pyzmq: 25.0.2, qtconsole: 5.4.2, qtpy: 2.3.1, rapids: 0.0.1, readchar: 4.0.5, requests: 2.28.2, requests-oauthlib: 1.3.1, rfc3339-validator: 0.1.4, rfc3986-validator: 0.1.1, rich: 13.3.3, rmm: 23.4.0, rsa: 4.9, scikit-learn: 1.2.2, scipy: 1.10.1, send2trash: 1.8.0, setuptools: 67.7.1, sip: 6.7.9, six: 1.16.0, sniffio: 1.3.0, soupsieve: 2.3.2.post1, sqlalchemy: 2.0.9, stack-data: 0.6.2, starlette: 0.22.0, starsessions: 1.3.0, statsmodels: 0.13.5, sympy: 1.11.1, tensorboard: 2.12.2, tensorboard-data-server: 0.7.0, tensorboard-plugin-wit: 1.8.1, tensorflow-io-gcs-filesystem: 0.32.0, termcolor: 2.2.0, terminado: 0.17.1, threadpoolctl: 3.1.0, tinycss2: 1.2.1, toml: 0.10.2, tomli: 2.0.1, torch: 2.0.0, torchmetrics: 0.11.4, tornado: 6.3, tqdm: 4.65.0, traitlets: 5.9.0, translationstring: 1.4, triton: 2.0.0, typing-extensions: 4.5.0, ujson: 5.7.0, urllib3: 1.26.15, uvicorn: 0.21.1, uvloop: 0.17.0, venusian: 3.0.0, watchfiles: 0.19.0, wcwidth: 0.2.6, webencodings: 0.5.1, webob: 1.8.7, websocket-client: 1.5.1, websockets: 11.0.1, werkzeug: 2.2.3, wheel: 0.40.0, widgetsnbextension: 4.0.7, yarl: 1.8.2, zipp: 3.15.0, zope.deprecation: 4.4.0, zope.interface: 6.0
* System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.10.10
    - version: #76-Ubuntu SMP Fri Mar 17 17:19:29 UTC 2023

More info

No response

arrow9577 commented 1 year ago

How can this be fixed? I ran into the same problem.

hgersten5 commented 1 year ago

Was there a solution for this? I am running into the same error.

ruuttt commented 10 months ago

You can find a solution here
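For anyone landing here without the link: the environment above has both `lightning` 2.0.1 and the standalone `pytorch-lightning` 2.0.1.post0 installed. A likely cause (an assumption, since the report does not show its imports) is that the `Trainer` comes from `pytorch_lightning` while the model subclasses `LightningModule` from `lightning.pytorch`, so the trainer's type check rejects it. The sketch below is a hypothetical diagnostic helper, not part of Lightning or pytorch-forecasting; it only inspects class metadata to show which package each object comes from.

```python
def lightning_namespaces(obj):
    """Return the top-level packages that define a class named
    ``LightningModule`` or ``Trainer`` in the MRO of ``obj``'s class.

    Hypothetical helper for diagnosis only; it relies purely on
    ``__mro__``, ``__name__`` and ``__module__``.
    """
    cls = obj if isinstance(obj, type) else type(obj)
    return sorted(
        {
            base.__module__.split(".")[0]
            for base in cls.__mro__
            if base.__name__ in ("LightningModule", "Trainer")
        }
    )


def same_namespace(model, trainer):
    """True if the model and the trainer share a Lightning package."""
    return bool(
        set(lightning_namespaces(model)) & set(lightning_namespaces(trainer))
    )
```

With the objects from the report, `lightning_namespaces(tft)` and `lightning_namespaces(trainer)` would show which package each comes from. If they differ, creating the trainer from the package the model uses (for example `import lightning.pytorch as pl` in place of `import pytorch_lightning as pl`) should make the two match.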