unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

AttributeError: 'Tensor' object has no attribute 'tile' #468

Closed: Sigvesor closed this issue 3 years ago

Sigvesor commented 3 years ago

Hello,

I just trained my first model with Darts on a GPU using CUDA. Training worked fine, but I ran into an issue when calling predict:

AttributeError: 'Tensor' object has no attribute 'tile'

I am running with: darts==0.11.0 torch==1.7.0

from darts.models import NBEATSModel

# Interpretable model
model_nbeats = NBEATSModel(
    input_chunk_length=168,
    output_chunk_length=48,
    generic_architecture=False,
    num_stacks=5,
    num_blocks=5,
    num_layers=4,
    layer_widths=168,
    n_epochs=30,
    nr_epochs_val_period=1,
    batch_size=800,
    model_name='nbeats_interpretable_cuda',
    force_reset=True,
    torch_device_str="cuda:0",
)
model_nbeats.fit(series=ts_train_list, val_series=val_ts_train, verbose=True)
# ts_train_list is a list of TimeSeries
# val_ts_train is a TimeSeries
model_nbeats.predict(n=48, series=ts_train[1])  # ts_train[1] is a TimeSeries

~\Anaconda3\envs\darts-venv\lib\site-packages\darts\models\torch_forecasting_model.py in _sample_tiling(self, input_data_tuple, batch_sample_size)
    750         for tensor in input_data_tuple:
    751             if tensor is not None:
--> 752                 tiled_input_data.append(tensor.tile((batch_sample_size, 1, 1)))
    753             else:
    754                 tiled_input_data.append(None)

AttributeError: 'Tensor' object has no attribute 'tile'

I have tried to do some research, but couldn't find much. Any idea on how to resolve this issue?
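For context, the failing line calls tensor.tile((batch_sample_size, 1, 1)), which repeats each 3-D input tensor batch_sample_size times along its first (batch) axis and leaves the other two axes unchanged. The sketch below mimics that shape behaviour with plain Python lists so it runs without torch; the function name tile_batch_axis is hypothetical and is not part of Darts or PyTorch.

```python
from typing import List, TypeVar

T = TypeVar("T")


def tile_batch_axis(batch: List[T], reps: int) -> List[T]:
    """Repeat a batch `reps` times along the first axis (hypothetical helper).

    Mirrors tensor.tile((reps, 1, 1)) for a 3-D tensor: the whole batch
    [b0, b1, ...] is repeated as a block, giving [b0, b1, ..., b0, b1, ...],
    while each sample's inner (time, feature) structure is kept as-is.
    """
    return [sample for _ in range(reps) for sample in batch]


# a batch of 2 samples, each with 3 time steps and 1 feature: shape (2, 3, 1)
batch = [[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]]
tiled = tile_batch_axis(batch, 3)
print(len(tiled), len(tiled[0]), len(tiled[0][0]))  # 6 3 1
```

In torch 1.7, which has no Tensor.tile, tensor.repeat(reps, 1, 1) gives the same result for a 3-D tensor, but upgrading torch to a version Darts supports is the cleaner fix.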

dennisbader commented 3 years ago

Hi Sigvesor,

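Our requirements for torch are >=1.8.0, <1.9.0 (see the Darts requirements). Tensor.tile was only added in torch 1.8.0, so it does not exist in torch 1.7.0, which is why predict fails in _sample_tiling.

Could you try upgrading torch and see if that resolves your issue?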
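Before re-running predict, it can help to confirm that the installed torch falls inside that range. Below is a minimal sketch using only the standard library; the helpers parse_version and torch_version_supported are hypothetical, and the bounds are hard-coded from the requirement quoted above (>=1.8.0, <1.9.0), so they would need updating for other Darts releases.

```python
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from strings like '1.8.0+cu101'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def torch_version_supported(
    version: str,
    lower: Tuple[int, int, int] = (1, 8, 0),
    upper: Tuple[int, int, int] = (1, 9, 0),
) -> bool:
    """Return True if lower <= version < upper (assumed requirement range)."""
    return lower <= parse_version(version) < upper


print(torch_version_supported("1.7.0"))        # False
print(torch_version_supported("1.8.0+cu101"))  # True
```

In a real environment you would pass torch.__version__ as the argument; the local build suffix (for example +cu101) is ignored by the parser.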

Sigvesor commented 3 years ago

Hi @dennisbader,

I upgraded:
pytorch 1.7.0-py3.7_cuda101_cudnn7_0 --> 1.8.0-py3.7_cuda10.1_cudnn7_0
torchvision 0.8.1-py37_cu101 --> 0.9.0-py37_cu101

This fixed the issue, thank you Dennis!