unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

[BUG] Tensor for argument #2 'mat1' is on CPU, but expected it to be on GPU (while checking arguments for addmm) #492

Closed. JulianNyarko closed this issue 2 years ago.

JulianNyarko commented 2 years ago

Similar to #453, but when training NBEATS() with a val_series, I get:

RuntimeError: Tensor for argument #2 'mat1' is on CPU, but expected it to be on GPU (while checking arguments for addmm)

hrzn commented 2 years ago

Hi @JulianNyarko , what version of Darts are you using?

JulianNyarko commented 2 years ago

Hi @hrzn , I’m using 0.12.0

hrzn commented 2 years ago

OK, thanks. Could you post the code snippet causing the issue, along with the complete stack trace? Thanks!

JulianNyarko commented 2 years ago

Sure! It's my first time posting an issue, so let me know if I'm not giving you everything you need!

My code snippet:

from darts.models import NBEATSModel

# epochs and the ts_list_* / cov_list_* series lists are defined elsewhere in the notebook (not shown)
model_NBEATS = NBEATSModel(input_chunk_length=1, output_chunk_length=1, n_epochs=epochs,
                           force_reset=True, batch_size=1000, torch_device_str="cuda:0")

model_NBEATS.fit(series=ts_list_train, past_covariates=cov_list_train,
                 val_series=ts_list_val, val_past_covariates=cov_list_val,
                 verbose=False)

The full stack trace:

RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_2099723/3260355276.py in <module>
    107     model_NBEATS = NBEATSModel(input_chunk_length=1, output_chunk_length=1, n_epochs=epochs, force_reset=True,
    108                           batch_size=1000, torch_device_str="cuda:0")
--> 109     model_NBEATS.fit(series=ts_list_train, verbose=False, past_covariates=cov_list_train, val_series=ts_list_val, 
    110                      val_past_covariates=cov_list_val
    111                     )

~/darts/lib/python3.8/site-packages/darts/utils/torch.py in decorator(self, *args, **kwargs)
     63         with fork_rng():
     64             manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
---> 65             return decorated(self, *args, **kwargs)
     66     return decorator

~/darts/lib/python3.8/site-packages/darts/models/torch_forecasting_model.py in fit(self, series, past_covariates, future_covariates, val_series, val_past_covariates, val_future_covariates, verbose, epochs)
    429         logger.info('Train dataset contains {} samples.'.format(len(train_dataset)))
    430 
--> 431         self.fit_from_dataset(train_dataset, val_dataset, verbose, epochs)
    432 
    433     @random_method

~/darts/lib/python3.8/site-packages/darts/utils/torch.py in decorator(self, *args, **kwargs)
     63         with fork_rng():
     64             manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
---> 65             return decorated(self, *args, **kwargs)
     66     return decorator

~/darts/lib/python3.8/site-packages/darts/models/torch_forecasting_model.py in fit_from_dataset(self, train_dataset, val_dataset, verbose, epochs)
    517 
    518         # Train model
--> 519         self._train(train_loader, val_loader, tb_writer, verbose, train_num_epochs)
    520 
    521         # Close tensorboard writer

~/darts/lib/python3.8/site-packages/darts/models/torch_forecasting_model.py in _train(self, train_loader, val_loader, tb_writer, verbose, epochs)
    823                 training_loss = total_loss / len(train_loader)
    824                 if val_loader is not None:
--> 825                     validation_loss = self._evaluate_validation_loss(val_loader)
    826                     if tb_writer is not None:
    827                         tb_writer.add_scalar("validation/loss_total", validation_loss, epoch)

~/darts/lib/python3.8/site-packages/darts/models/torch_forecasting_model.py in _evaluate_validation_loss(self, val_loader)
    848         with torch.no_grad():
    849             for batch_idx, val_batch in enumerate(val_loader):
--> 850                 output = self._produce_train_output(val_batch[:-1])
    851                 target = val_batch[-1]
    852                 loss = self._compute_loss(output, target)

~/darts/lib/python3.8/site-packages/darts/models/torch_forecasting_model.py in _produce_train_output(self, input_batch)
   1109         # Currently all our PastCovariates models require past target and covariates concatenated
   1110         inpt = torch.cat([past_target, past_covariate], dim=2) if past_covariate is not None else past_target
-> 1111         return self.model(inpt)
   1112 
   1113     def _get_batch_prediction(self, n: int, input_batch: Tuple, roll_size: int) -> Tensor:

~/darts/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/darts/lib/python3.8/site-packages/darts/models/nbeats.py in forward(self, x)
    339         for stack in self.stacks_list:
    340             # compute stack output
--> 341             stack_residual, stack_forecast = stack(x)
    342 
    343             # add stack forecast to final output

~/darts/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/darts/lib/python3.8/site-packages/darts/models/nbeats.py in forward(self, x)
    213         for block in self.blocks_list:
    214             # pass input through block
--> 215             x_hat, y_hat = block(x)
    216 
    217             # add block forecast to stack forecast

~/darts/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/darts/lib/python3.8/site-packages/darts/models/nbeats.py in forward(self, x)
    133         # fully connected layer stack
    134         for layer in self.linear_layer_stack_list:
--> 135             x = self.relu(layer(x))
    136 
    137         # forked linear layers producing waveform generator parameters

~/darts/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/darts/lib/python3.8/site-packages/torch/nn/modules/linear.py in forward(self, input)
     92 
     93     def forward(self, input: Tensor) -> Tensor:
---> 94         return F.linear(input, self.weight, self.bias)
     95 
     96     def extra_repr(self) -> str:

~/darts/lib/python3.8/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1751     if has_torch_function_variadic(input, weight):
   1752         return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
-> 1753     return torch._C._nn.linear(input, weight, bias)
   1754 
   1755 

RuntimeError: Tensor for argument #2 'mat1' is on CPU, but expected it to be on GPU (while checking arguments for addmm)

hrzn commented 2 years ago

OK, thanks. It looks like you are using 0.10.1 or earlier: in your stack trace, line 850 of torch_forecasting_model.py does not send the validation batch to the correct device. That line changed in v0.11.0, which fixes the issue. Please run pip install -U darts and try again. You can double-check which version you're using by running pip freeze | grep darts.
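For context on the error itself: argument #2 ('mat1') of addmm is the input to an nn.Linear layer, so the message means the layer's weights are on the GPU while the incoming batch is still on the CPU. In the trace this happens inside _evaluate_validation_loss, when the validation input reaches the first linear layer of an NBEATS block. The sketch below illustrates the general PyTorch pattern of moving each batch to the model's device before the forward pass; it is not Darts' actual code, and move_batch_to_device and evaluate_validation_loss are hypothetical helpers that assume each batch is an (input, target) tuple and that all model parameters live on one device.

```python
import torch
from torch import nn


def move_batch_to_device(batch, device):
    """Return a tuple with every tensor in `batch` moved to `device` (non-tensors such as None are kept)."""
    return tuple(t.to(device) if isinstance(t, torch.Tensor) else t for t in batch)


def evaluate_validation_loss(model, val_loader, loss_fn):
    """Average validation loss, moving each batch to the model's device first.

    Hypothetical helper: assumes each batch is (input, target) and that the
    model's parameters all live on a single device.
    """
    device = next(model.parameters()).device
    model.eval()
    total = 0.0
    with torch.no_grad():
        for batch in val_loader:
            # Without this step, a CPU batch hitting GPU weights raises the addmm RuntimeError
            inputs, target = move_batch_to_device(batch, device)
            total += loss_fn(model(inputs), target).item()
    return total / len(val_loader)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 1).to(device)
# The batches are built on the CPU, as DataLoader batches typically are
val_loader = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]
loss = evaluate_validation_loss(model, val_loader, nn.MSELoss())
print(loss)
```

Reading the device from the model's parameters, rather than hard-coding "cuda:0", keeps the same loop working on CPU-only machines.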

JulianNyarko commented 2 years ago

That is really odd. pip definitely tells me I am running 0.12.0...


But if you're saying it's an issue on my end, I'll see if there's anything I can do to fix it. Thanks!
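pip freeze reports what the pip found on the shell's PATH sees, which is not necessarily the interpreter behind the notebook kernel. A minimal sketch, assuming Python 3.8+ (for importlib.metadata), that asks the running interpreter which darts version it has and compares it against 0.11.0, the release named in this thread as containing the fix. The helper names are hypothetical, and parse_version is deliberately simplified: it only reads the leading major.minor.patch numbers and ignores suffixes such as rc1.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def parse_version(v):
    """Turn a string like '0.12.0' into a comparable tuple (0, 12, 0); suffixes are ignored."""
    m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", v)
    if m is None:
        raise ValueError(f"unrecognised version string: {v!r}")
    return tuple(int(g or 0) for g in m.groups())


def installed_version(dist="darts"):
    """Version of `dist` as seen by *this* interpreter, or None if it is not installed here."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def has_validation_device_fix(v):
    """True if `v` is at least 0.11.0, the release said in this thread to fix the bug."""
    return parse_version(v) >= (0, 11, 0)


v = installed_version()
print("darts version in this interpreter:", v)
if v is not None:
    print("includes the 0.11.0 fix:", has_validation_device_fix(v))
```

If this prints an older version than pip freeze does, the notebook and the shell are using different environments.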

hrzn commented 2 years ago

The stack trace really does seem to indicate this. Perhaps you are executing your code (or notebook) in a different virtual environment?
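One way to test the virtual-environment hypothesis from inside the notebook is to print which interpreter is running and where the darts package would be imported from. A minimal stdlib-only sketch; describe_environment is a hypothetical helper, and it uses importlib.util.find_spec so the module is located without being imported.

```python
import importlib.util
import sys


def describe_environment(module="darts"):
    """Report the running interpreter and where `module` would be imported from (None if absent)."""
    spec = importlib.util.find_spec(module)
    return {
        "executable": sys.executable,
        "prefix": sys.prefix,
        "module_origin": spec.origin if spec is not None else None,
    }


for key, value in describe_environment().items():
    print(f"{key}: {value}")
```

If module_origin points into a different site-packages than the one pip upgraded, running "<executable> -m pip install -U darts" with the printed executable installs into the environment the notebook actually uses.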

JulianNyarko commented 2 years ago

Hm, no, that's not it... But I'll try to do a fresh install of darts tomorrow. Hopefully, that'll fix it!

hrzn commented 2 years ago

OK, please let me know.

JulianNyarko commented 2 years ago

Working now. Thanks again!

hrzn commented 2 years ago

Great 👍