awslabs / gluonts

Probabilistic time series modeling in Python
https://ts.gluon.ai
Apache License 2.0
4.65k stars 756 forks

TemporalFusionTransformer fails on GPU if non-default num_outputs is provided #2264

Open shchur opened 2 years ago

shchur commented 2 years ago

Description

TemporalFusionTransformerEstimator crashes when training on GPU with num_outputs != 3 (any non-default value).

To Reproduce

import mxnet as mx
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.model.tft import TemporalFusionTransformerEstimator
from gluonts.mx.trainer import Trainer

dataset = get_dataset("m3_other")
ctx = mx.context.gpu()  # the crash only occurs on GPU
model = TemporalFusionTransformerEstimator(
    trainer=Trainer(ctx=ctx),
    freq=dataset.metadata.freq,
    prediction_length=dataset.metadata.prediction_length,
    num_outputs=9,  # any value other than the default 3 triggers the crash
)
model.train(dataset.train)

If I change ctx = mx.context.cpu() or set num_outputs=3, the problem disappears.

Error message or code output

Traceback (most recent call last):
  File "/local/home/shchuro/workspace/autogluon/shchuro/train_gluonts.py", line 14, in <module>
    model.train(dataset.train)
  File "/home/shchuro/miniconda3/envs/ag/lib/python3.9/site-packages/gluonts/mx/model/estimator.py", line 194, in train
    return self.train_model(
  File "/home/shchuro/miniconda3/envs/ag/lib/python3.9/site-packages/gluonts/mx/model/estimator.py", line 169, in train_model
    self.trainer(
  File "/home/shchuro/miniconda3/envs/ag/lib/python3.9/site-packages/gluonts/mx/trainer/_base.py", line 436, in __call__
    epoch_loss = loop(
  File "/home/shchuro/miniconda3/envs/ag/lib/python3.9/site-packages/gluonts/mx/trainer/_base.py", line 334, in loop
    _ = net(*batch.values())
  File "/home/shchuro/miniconda3/envs/ag/lib/python3.9/site-packages/mxnet/gluon/block.py", line 825, in __call__
    out = self.forward(*args)
  File "/home/shchuro/miniconda3/envs/ag/lib/python3.9/site-packages/mxnet/gluon/block.py", line 1492, in forward
    return self._call_cached_op(x, *args)
  File "/home/shchuro/miniconda3/envs/ag/lib/python3.9/site-packages/mxnet/gluon/block.py", line 1233, in _call_cached_op
    out = self._cached_op(*cargs)
  File "/home/shchuro/miniconda3/envs/ag/lib/python3.9/site-packages/mxnet/_ctypes/ndarray.py", line 148, in __call__
    check_call(_LIB.MXInvokeCachedOpEx(
  File "/home/shchuro/miniconda3/envs/ag/lib/python3.9/site-packages/mxnet/base.py", line 246, in check_call
    raise get_last_ffi_error()
mxnet.base.MXNetError: Traceback (most recent call last):
  File "../3rdparty/tvm/nnvm/src/core/graph.cc", line 109
MXNetError: Check failed: it != node2index_.end(): control dep not found in graph

Environment

lostella commented 2 years ago

@shchur one observation: setting hybridize=False in the Trainer also makes the problem disappear.

lostella commented 2 years ago

Related:

shchur commented 2 years ago

> @shchur one observation: setting hybridize=False in the Trainer also makes the problem disappear.

Thanks a lot for looking into this @lostella! I think hybridize=False is a good enough workaround for our purposes in AutoGluon (until we switch to the PyTorch implementation).

lostella commented 1 year ago

@shchur I'm assuming the issue doesn't show up in the implementation from #2536? 🙃

shchur commented 1 year ago

No, the PyTorch version works fine on GPU :)