unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

Run on CPU not on gpu #722

Closed · ElVictorious closed this issue 2 years ago

ElVictorious commented 2 years ago

Hi,

First, congrats on the amazing job with this repo.

I've found that the Temporal Fusion Transformer (`TFTModel`) does not work on GPU, although it works fine on CPU. I tried it on 3 different computers, including Colab, and ran into the same problem every time.

Thanks in advance for your help

```python
from darts.models import TFTModel
from darts.utils.likelihood_models import QuantileRegression
from torch.nn import MSELoss

# INLEN, N_FC, HIDDEN, LSTMLAYERS, ATTHEADS, DROPOUT, BATCH, EPOCHS,
# QUANTILES and RAND are hyperparameters defined earlier in my notebook.
model = TFTModel(
    input_chunk_length=INLEN,
    output_chunk_length=N_FC,
    hidden_size=HIDDEN,
    lstm_layers=LSTMLAYERS,
    num_attention_heads=ATTHEADS,
    dropout=DROPOUT,
    batch_size=BATCH,
    n_epochs=EPOCHS,
    likelihood=QuantileRegression(quantiles=QUANTILES),
    loss_fn=MSELoss(),
    random_state=RAND,
    force_reset=True,
)

model.fit(ts_ttrain, future_covariates=tcov, verbose=True)
```

```
RuntimeError                              Traceback (most recent call last)
in
     15 model.fit(ts_ttrain,
     16           future_covariates=tcov,
---> 17           verbose=True)

9 frames

/usr/local/lib/python3.7/dist-packages/darts/utils/torch.py in decorator(self, *args, **kwargs)
---> 65     return decorated(self, *args, **kwargs)

/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/torch_forecasting_model.py in fit(self, series, past_covariates, future_covariates, val_series, val_past_covariates, val_future_covariates, verbose, epochs, max_samples_per_ts, num_loader_workers)
--> 479     self.fit_from_dataset(train_dataset, val_dataset, verbose, epochs, num_loader_workers)

/usr/local/lib/python3.7/dist-packages/darts/utils/torch.py in decorator(self, *args, **kwargs)
---> 65     return decorated(self, *args, **kwargs)

/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/torch_forecasting_model.py in fit_from_dataset(self, train_dataset, val_dataset, verbose, epochs, num_loader_workers)
--> 593     self._train(train_loader, val_loader, tb_writer, verbose, train_num_epochs)

/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/torch_forecasting_model.py in _train(self, train_loader, val_loader, tb_writer, verbose, epochs)
--> 888     output = self._produce_train_output(train_batch[:-1])

/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/tft_model.py in _produce_train_output(self, input_batch)
--> 802     return self.model(input_batch)

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
-> 1102     return forward_call(*input, **kwargs)

/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/tft_model.py in forward(self, x)
    448     # run local lstm encoder
--> 449     encoder_out, (hidden, cell) = self.lstm_encoder(input=embeddings_varying_encoder, hx=(input_hidden, input_cell))

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
-> 1102     return forward_call(*input, **kwargs)

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
    691     result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
--> 692                       self.dropout, self.training, self.bidirectional, self.batch_first)

RuntimeError: rnn: hx is not contiguous
```
dennisbader commented 2 years ago

I could reproduce the error on Colab. This happens with `lstm_layers > 1`. Can you try with `lstm_layers=1`? A sketch of that workaround follows below.
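
For reference, a minimal sketch of the workaround, reusing the hyperparameter names from your snippet above (everything except `lstm_layers` is unchanged):

```python
# Suggested workaround (sketch): same configuration as above, but with a single
# LSTM layer so the hidden state does not need to be expanded across layers.
model = TFTModel(
    input_chunk_length=INLEN,
    output_chunk_length=N_FC,
    hidden_size=HIDDEN,
    lstm_layers=1,  # instead of LSTMLAYERS > 1
    num_attention_heads=ATTHEADS,
    dropout=DROPOUT,
    batch_size=BATCH,
    n_epochs=EPOCHS,
    likelihood=QuantileRegression(quantiles=QUANTILES),
    random_state=RAND,
    force_reset=True,
)
model.fit(ts_ttrain, future_covariates=tcov, verbose=True)
```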

I assume the non-contiguity comes from expanding the hidden state tensors for the LSTM layers. I will test whether we can fix this with `contiguous()`.
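
To illustrate the suspicion, here is a small standalone PyTorch sketch (not darts code; the names are made up) showing that broadcasting a hidden state with `expand()` yields a non-contiguous view, which the LSTM rejects on the CUDA path with exactly this error, and that calling `.contiguous()` on the expanded tensors avoids it:

```python
import torch
import torch.nn as nn

# Standalone sketch (not darts code): reproduce the non-contiguous hx situation.
num_layers, batch, hidden = 4, 8, 16
lstm = nn.LSTM(input_size=hidden, hidden_size=hidden,
               num_layers=num_layers, batch_first=True)

x = torch.randn(batch, 5, hidden)

# Broadcasting a single-layer state across all layers with expand() gives a
# view with stride 0 in the layer dimension, i.e. a non-contiguous tensor.
h0 = torch.zeros(1, batch, hidden).expand(num_layers, -1, -1)
c0 = torch.zeros(1, batch, hidden).expand(num_layers, -1, -1)
print(h0.is_contiguous())  # False

# Passing (h0, c0) as-is can trigger "RuntimeError: rnn: hx is not contiguous"
# on GPU; making the states contiguous first avoids it.
out, (hn, cn) = lstm(x, (h0.contiguous(), c0.contiguous()))
```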