unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0
7.91k stars, 857 forks

NBEATS mismatch for tensor on GPU vs CPU #298

Closed: rmenzie closed this issue 3 years ago.

rmenzie commented 3 years ago

When using the NBEATS model with generic_architecture=False, it appears that one tensor is allocated on the CPU while the rest of the model is on the GPU:

RuntimeError Traceback (most recent call last)

in
      7 model_name='nbeats_run_interp')
      8
----> 9 nbeatsModel.fit(rescaled_train, val_series = rescaled_test, verbose = True)

~/.local/lib/python3.8/site-packages/darts/utils/torch.py in decorator(self, *args, **kwargs)
     63 with fork_rng():
     64 manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
---> 65 decorated(self, *args, **kwargs)
     66 return decorator

~/.local/lib/python3.8/site-packages/darts/models/torch_forecasting_model.py in fit(self, series, covariates, val_series, val_covariates, verbose)
    292 logger.info('Train dataset contains {} samples.'.format(len(train_dataset)))
    293
--> 294 self.fit_from_dataset(train_dataset, val_dataset, verbose)
    295
    296 @random_method

~/.local/lib/python3.8/site-packages/darts/utils/torch.py in decorator(self, *args, **kwargs)
     63 with fork_rng():
     64 manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
---> 65 decorated(self, *args, **kwargs)
     66 return decorator

~/.local/lib/python3.8/site-packages/darts/models/torch_forecasting_model.py in fit_from_dataset(self, train_dataset, val_dataset, verbose)
    345
    346 # Train model
--> 347 self._train(train_loader, val_loader, tb_writer, verbose)
    348
    349 # Close tensorboard writer

~/.local/lib/python3.8/site-packages/darts/models/torch_forecasting_model.py in _train(self, train_loader, val_loader, tb_writer, verbose)
    499 self.model.train()
    500 data, target = data.to(self.device), target.to(self.device)
--> 501 output = self.model(data)
    502 loss = self.criterion(output, target)
    503 self.optimizer.zero_grad()

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725 result = self._slow_forward(*input, **kwargs)
    726 else:
--> 727 result = self.forward(*input, **kwargs)
    728 for hook in itertools.chain(
    729 _global_forward_hooks.values(),

~/.local/lib/python3.8/site-packages/darts/models/nbeats.py in forward(self, x)
    306 for stack in self.stacks_list:
    307 # compute stack output
--> 308 stack_residual, stack_forecast = stack(x)
    309
    310 # add stack forecast to final output

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725 result = self._slow_forward(*input, **kwargs)
    726 else:
--> 727 result = self.forward(*input, **kwargs)
    728 for hook in itertools.chain(
    729 _global_forward_hooks.values(),

~/.local/lib/python3.8/site-packages/darts/models/nbeats.py in forward(self, x)
    210 for block in self.blocks_list:
    211 # pass input through block
--> 212 x_hat, y_hat = block(x)
    213
    214 # add block forecast to stack forecast

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725 result = self._slow_forward(*input, **kwargs)
    726 else:
--> 727 result = self.forward(*input, **kwargs)
    728 for hook in itertools.chain(
    729 _global_forward_hooks.values(),

~/.local/lib/python3.8/site-packages/darts/models/nbeats.py in forward(self, x)
    137
    138 # waveform generator applications
--> 139 x_hat = self.backcast_g(theta_backcast)
    140 y_hat = self.forecast_g(theta_forecast)
    141

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725 result = self._slow_forward(*input, **kwargs)
    726 else:
--> 727 result = self.forward(*input, **kwargs)
    728 for hook in itertools.chain(
    729 _global_forward_hooks.values(),

~/.local/lib/python3.8/site-packages/darts/models/nbeats.py in forward(self, x)
     36
     37 def forward(self, x):
---> 38 return torch.matmul(x, self.T.float().T)
     39
     40

RuntimeError: Tensor for argument #3 'mat2' is on CPU, but expected it to be on GPU (while checking arguments for addmm)
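The last frame fails in torch.matmul(x, self.T.float().T) because self.T is still on the CPU. In PyTorch, a frequent cause of this kind of error is a constant tensor stored as a plain module attribute: nn.Module.to(device) only moves registered parameters and buffers, so such a tensor is left behind. The following is a minimal sketch of that failure pattern and the usual register_buffer remedy, plus a small helper for spotting unregistered tensors. PlainAttrBasis, BufferBasis and unregistered_tensors are hypothetical names, not darts APIs, and this is not necessarily how darts resolved the issue in v0.9.0.

```python
import torch
import torch.nn as nn


class PlainAttrBasis(nn.Module):
    """Hypothetical basis layer that keeps its fixed matrix as a plain attribute."""

    def __init__(self, expansion_dim: int, target_length: int):
        super().__init__()
        t = torch.arange(target_length, dtype=torch.float32) / target_length
        # Plain tensor attribute: not a parameter or buffer, so module.to(device) skips it.
        self.T = torch.stack([t ** i for i in range(expansion_dim)], dim=1)  # (target_length, expansion_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Raises a device-mismatch RuntimeError if x is on the GPU but self.T stayed on the CPU.
        return torch.matmul(x, self.T.T)


class BufferBasis(nn.Module):
    """Same hypothetical layer, with the fixed matrix registered as a non-trainable buffer."""

    def __init__(self, expansion_dim: int, target_length: int):
        super().__init__()
        t = torch.arange(target_length, dtype=torch.float32) / target_length
        # Buffers are moved/cast by .to(), .cuda(), .double() and are saved in state_dict().
        self.register_buffer("T", torch.stack([t ** i for i in range(expansion_dim)], dim=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, self.T.T)


def unregistered_tensors(model: nn.Module) -> list:
    """Return 'module.attr' names of tensors stored as plain attributes (ignored by .to())."""
    found = []
    for mod_name, module in model.named_modules():
        for attr, value in vars(module).items():
            # Parameters and buffers live in module._parameters / module._buffers,
            # so a tensor found directly in the instance __dict__ was never registered.
            if isinstance(value, torch.Tensor):
                found.append(f"{mod_name or '<root>'}.{attr}")
    return found


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = BufferBasis(expansion_dim=3, target_length=5).to(device)
    x = torch.randn(2, 3, device=device)
    print(model(x).shape)  # torch.Size([2, 5])
    print(unregistered_tensors(PlainAttrBasis(3, 5)))  # ['<root>.T']
```

Registering the matrix as a buffer rather than an nn.Parameter keeps it out of the optimizer while still letting it follow the module across devices and dtypes. An alternative is to move the tensor inside forward (for example self.T.to(x.device)), which works but copies it on every call.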
hrzn commented 3 years ago

Fixed in v0.9.0