dreamquark-ai / tabnet

PyTorch implementation of the TabNet paper: https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License
2.56k stars, 473 forks

CUDA error: device-side assert triggered #432

Closed: ooghry closed this issue 1 year ago

ooghry commented 1 year ago

I'm trying to use Optuna with TabNet on Google Colab. At some point, I got this error. (I don't know whether this is the real stack trace or whether the actual cause is hidden behind the CUDA error.)

[W 2022-09-09 01:11:06,404] Trial 2 failed because of the following error: RuntimeError('CUDA error: device-side assert triggered')
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "<ipython-input-20-520550c8acce>", line 79, in __call__
    raise e
  File "<ipython-input-20-520550c8acce>", line 64, in __call__
    callbacks=[my_callback],
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/abstract_model.py", line 223, in fit
    self._train_epoch(train_dataloader)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/abstract_model.py", line 434, in _train_epoch
    batch_logs = self._train_batch(X, y)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/abstract_model.py", line 469, in _train_batch
    output, M_loss = self.network(X)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/tab_network.py", line 583, in forward
    return self.tabnet(x)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/tab_network.py", line 468, in forward
    steps_output, M_loss = self.encoder(x)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/tab_network.py", line 160, in forward
    M = self.att_transformers[step](prior, att)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/tab_network.py", line 637, in forward
    x = self.selector(x)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/sparsemax.py", line 109, in forward
    return sparsemax(input, self.dim)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/sparsemax.py", line 52, in forward
    tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/sparsemax.py", line 94, in _threshold_and_support
    tau = input_cumsum.gather(dim, support_size - 1)
RuntimeError: CUDA error: device-side assert triggered
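
Regarding whether this trace can be trusted: CUDA kernels run asynchronously, so a device-side assert may be reported at a later line than the operation that actually failed, and once it fires the CUDA context stays unusable until the process is restarted (so later trials in the same study would likely fail as well). A minimal sketch of two common ways to get a trustworthy trace, assuming a freshly restarted Colab runtime; `debug_overrides` is a hypothetical name, while `device_name` is the existing TabNet constructor argument already commented out in the code below.

```python
import os

# Make CUDA kernel launches synchronous so the Python traceback stops at the
# operation that actually failed. This only takes effect if it is set before
# CUDA is initialised, so run it first in a fresh runtime.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Alternative: rerun the failing trial on CPU, where an out-of-bounds index is
# raised as a regular Python error. Merge into tabnet_params for that run.
debug_overrides = dict(device_name="cpu")
```

The CPU route is usually the quicker of the two on Colab, since it needs no runtime restart before the first attempt.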

The only unusual thing I do is pass the weights={0:2.,1:2.,2:.001} argument to fit:

import numpy as np
import torch
from sklearn.metrics import confusion_matrix
from pytorch_tabnet.metrics import Metric
from pytorch_tabnet.tab_model import TabNetClassifier


class my_metric(Metric):
    def __init__(self):
        self._name = "custom"  # write an understandable name here
        self._maximize = True

    def __call__(self, y_true, y_score):
        y_pred = np.argmax(y_score, axis=1)
        confusion = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
        # Reward correct predictions of classes 0 and 1; penalise 0/1 mix-ups
        # and class-2 rows predicted as 0 or 1.
        return ((confusion[0][0] + confusion[1][1]) * 2) - (
            confusion[0][1] + confusion[1][0] + confusion[2][0] + confusion[2][1]
        )


# The code below runs inside the Optuna objective's __call__(self, trial);
# custom_callback and the self.X_* / self.y_* data are defined elsewhere.
mask_type = trial.suggest_categorical("mask_type", ["entmax", "sparsemax"])

n_da = trial.suggest_int("n_da", 8, 512, step=4)
n_steps = trial.suggest_int("n_steps", 3, 30, step=1)
gamma = trial.suggest_float("gamma", 1., 4., step=0.2)
n_independent = trial.suggest_int("n_independent", 1, 15, step=1)
n_shared = trial.suggest_int("n_shared", 1, 15)

lambda_sparse = trial.suggest_float("lambda_sparse", 1e-6, 1e-3, log=True)

tabnet_params = dict(
    # device_name='cuda',
    n_d=n_da,
    n_a=n_da,
    n_steps=n_steps,
    gamma=gamma,
    n_independent=n_independent,
    n_shared=n_shared,
    lambda_sparse=lambda_sparse,
    optimizer_fn=torch.optim.Adam,
    scheduler_fn=torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_params=dict(
        # The scheduler is stepped with the last eval metric ("custom"),
        # which is maximized, so it should look for a maximum.
        mode="max",
        # keep scheduler patience lower than early-stopping patience
        patience=trial.suggest_int("patienceScheduler", low=3, high=10),
        min_lr=1e-5,
        factor=0.5,
    ),
    mask_type=mask_type,
    verbose=1,
    seed=2022,
)
classifier = TabNetClassifier(**tabnet_params)
my_callback = custom_callback()
classifier.fit(
    np.array(self.X_train),
    np.array(self.y_train.values.ravel()),
    eval_set=[(np.array(self.X_validation), np.array(self.y_validation.values.ravel()))],
    max_epochs=200,
    patience=20,
    batch_size=256,
    eval_metric=[my_metric],
    weights={0:2.,1:2.,2:.001},  # per-class sampling weights
    pin_memory=False,
    callbacks=[my_callback],
)
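
Per the pytorch_tabnet documentation, a dict passed as `weights` to `TabNetClassifier.fit` is a sampling parameter (class → weight): it changes how often each class is drawn into training batches rather than weighting the loss. Assuming each row is drawn with probability proportional to its class weight, as with a weighted random sampler, the expected class mix can be computed directly. The helper name and the class counts below are hypothetical, for illustration only.

```python
def expected_class_share(class_counts, class_weights):
    """Expected fraction of each class among sampled training rows when each
    row is drawn with probability proportional to the weight of its class."""
    mass = {c: class_weights[c] * n for c, n in class_counts.items()}
    total = sum(mass.values())
    return {c: m / total for c, m in mass.items()}


# Hypothetical class counts, only to illustrate the effect of the dict.
counts = {0: 1000, 1: 1000, 2: 8000}
shares = expected_class_share(counts, {0: 2.0, 1: 2.0, 2: 0.001})
for c, s in shares.items():
    print(f"class {c}: {s:.2%}")
```

With these made-up counts, class 2 falls to roughly 0.2% of the sampled rows even though it is 80% of the data, so the dict is a strong undersampling of class 2 rather than a mild reweighting.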

Any comment would be appreciated.

ooghry commented 1 year ago

Running the same code on CPU gives this error instead:

RuntimeError('index -1 is out of bounds for dimension 1 with size 10260')
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "<ipython-input-12-e61dd5c8773c>", line 64, in __call__
    callbacks=[my_callback],
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/abstract_model.py", line 223, in fit
    self._train_epoch(train_dataloader)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/abstract_model.py", line 434, in _train_epoch
    batch_logs = self._train_batch(X, y)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/abstract_model.py", line 469, in _train_batch
    output, M_loss = self.network(X)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/tab_network.py", line 583, in forward
    return self.tabnet(x)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/tab_network.py", line 468, in forward
    steps_output, M_loss = self.encoder(x)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/tab_network.py", line 160, in forward
    M = self.att_transformers[step](prior, att)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/tab_network.py", line 637, in forward
    x = self.selector(x)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/sparsemax.py", line 204, in forward
    return entmax15(input, self.dim)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/sparsemax.py", line 127, in forward
    tau_star, _ = Entmax15Function._threshold_and_support(input, dim)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_tabnet/sparsemax.py", line 159, in _threshold_and_support
    tau_star = tau.gather(dim, support_size - 1)
RuntimeError: index -1 is out of bounds for dimension 1 with size 10260
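
The CPU message makes the failure concrete. The last frame (`tau.gather(dim, support_size - 1)` in the entmax15 path, the same pattern as the sparsemax line in the CUDA trace) asked for index -1, so `support_size` was 0 for at least one row. For finite inputs, after the max-shift the largest entry is 0 and always passes the support test (0 > 0 - 1), so a support of 0 points to non-finite values reaching the attentive transformer. The sketch below is a pure-Python re-implementation of the sparsemax support/threshold step for one row, written for illustration rather than taken from pytorch_tabnet; it assumes, as with `torch.max`, that a NaN in the row propagates through the max-shift.

```python
import math


def sparsemax_threshold(row):
    """Illustrative pure-Python mirror of sparsemax's support/threshold step
    for a single row; returns (tau, support_size)."""
    # Shift by the row maximum. torch.max propagates NaN, so emulate that:
    # one NaN makes the shift NaN and therefore every shifted entry NaN.
    # An inf entry leads to the same failure, since inf - inf is NaN.
    shift = math.nan if any(math.isnan(v) for v in row) else max(row)
    z = sorted((v - shift for v in row), reverse=True)

    cumsum, running = [], 0.0
    for v in z:
        running += v
        cumsum.append(running - 1.0)  # input_cumsum = cumsum(sorted) - 1

    # Support test: rho * z_sorted > cumsum - 1; any comparison with NaN is False.
    support_size = sum(
        1 for rho, (zj, cj) in enumerate(zip(z, cumsum), start=1) if rho * zj > cj
    )

    idx = support_size - 1  # the index handed to gather()
    if not 0 <= idx < len(row):
        # A Python list would silently wrap -1 to the last element; torch.gather
        # raises on CPU and trips a device-side assert on CUDA instead.
        raise IndexError(
            f"index {idx} is out of bounds for dimension 1 with size {len(row)}"
        )
    return cumsum[idx] / support_size, support_size
```

For example, `sparsemax_threshold([1.0, 2.0, 3.0])` returns `(-1.0, 1)`: only the largest entry is in the support, and `max(z - tau, 0)` on the shifted row gives `[0, 0, 1]`. Replacing any entry with NaN raises the same "index -1" error as the CPU run.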

Optimox commented 1 year ago

@ooghry I think this error happens in a few cases that you should investigate:

It's probably one or the other; let me know if it's not.

ooghry commented 1 year ago

@Optimox Thank you for your reply. I don't have any NaNs or categorical features in my data. I added cat_idxs=[], cat_dims=[], cat_emb_dim=1 as arguments, and my first problem seems to be gone.
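
Both points (no NaNs, categorical setup) can be double-checked before calling `fit`. A minimal stdlib sketch, where the helper name `preflight_check` is hypothetical: it flags non-finite values anywhere and, for columns listed in `cat_idxs`, values that are not integer codes in `[0, cat_dim)`, the only range an embedding lookup accepts. With NumPy or pandas data, pass `X.tolist()` or `X.values.tolist()`.

```python
import math


def preflight_check(X, cat_idxs=(), cat_dims=()):
    """Hypothetical helper: list problems in a 2-D table X (a list of rows)
    that are worth ruling out before calling TabNet's fit."""
    cardinality = dict(zip(cat_idxs, cat_dims))
    problems = []
    for r, row in enumerate(X):
        for c, v in enumerate(row):
            if not math.isfinite(v):
                problems.append(f"row {r}, col {c}: non-finite value {v}")
            elif c in cardinality and not (
                float(v).is_integer() and 0 <= v < cardinality[c]
            ):
                problems.append(
                    f"row {r}, col {c}: category code {v} outside [0, {cardinality[c]})"
                )
    return problems


print(preflight_check([[1.0, 2.0], [float("nan"), 0.0]]))
print(preflight_check([[0.5, 3]], cat_idxs=[1], cat_dims=[3]))
```

An empty list means neither problem was found; it does not rule out non-finite activations appearing later in training.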

However, my notebook crashes after some trials on Google Colab (25 GB RAM) because it runs out of memory. It would be nice to have control over how much memory TabNet can use.

Optimox commented 1 year ago

So you were probably using embeddings, but something was wrong with the indexing. Memory overflow can happen, and I'll be happy to discuss how we can reduce TabNet's memory consumption, but I think it's fine as long as there is no memory leak. Just reduce your batch size if needed, or the size of your model if you don't have enough computing power.
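
In an Optuna study, reducing the model size mostly means tightening the search space: the objective in the first post lets `n_d = n_a` go up to 512, `n_steps` up to 30, and `n_independent`/`n_shared` up to 15. A sketch with smaller upper bounds follows; the caps are purely illustrative assumptions, not values recommended in this thread. `UpperBoundTrial` is a hypothetical stand-in so the function runs without Optuna; a real `optuna.trial.Trial` (or `optuna.trial.FixedTrial`) exposes the same `suggest_*` methods.

```python
def suggest_tabnet_params(trial, max_width=64, max_steps=10, max_glu_layers=4):
    """Same hyperparameters as the original objective, with hypothetical caps
    on the ones that drive model size (width, number of steps, GLU layers)."""
    n_da = trial.suggest_int("n_da", 8, max_width, step=8)
    return dict(
        n_d=n_da,
        n_a=n_da,
        n_steps=trial.suggest_int("n_steps", 3, max_steps),
        gamma=trial.suggest_float("gamma", 1.0, 4.0, step=0.2),
        n_independent=trial.suggest_int("n_independent", 1, max_glu_layers),
        n_shared=trial.suggest_int("n_shared", 1, max_glu_layers),
        lambda_sparse=trial.suggest_float("lambda_sparse", 1e-6, 1e-3, log=True),
        mask_type=trial.suggest_categorical("mask_type", ["entmax", "sparsemax"]),
    )


class UpperBoundTrial:
    """Hypothetical stand-in for an Optuna trial that always picks the largest
    allowed value, i.e. the biggest model this search space can produce."""

    def suggest_int(self, name, low, high, step=1, log=False):
        return high

    def suggest_float(self, name, low, high, step=None, log=False):
        return high

    def suggest_categorical(self, name, choices):
        return choices[0]


largest = suggest_tabnet_params(UpperBoundTrial())
print(largest["n_d"], largest["n_steps"], largest["n_shared"])
```

Probing the upper corner like this is a cheap way to check that the biggest configuration fits in memory before launching the study; `batch_size` in the `fit` call can be lowered independently.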