dreamquark-ai / tabnet

PyTorch implementation of the TabNet paper: https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License

Making the computation of feature importance optional #493

Closed: CesarLeblanc closed this issue 1 year ago.

CesarLeblanc commented 1 year ago

Feature request

After talking with colleagues who also use TabNet and reading some previous issues on this topic, it seems that feature importance is not always needed. For example, during hyper-parameter tuning, hundreds of combinations may be tried, and feature importance is only required for the best model once it has been selected, not for every candidate. Moreover, computing feature importance is very time-consuming, sometimes taking longer than training the neural network itself, especially with high-dimensional input data. I therefore propose adding a compute_importance parameter to the fit method of the TabModel class: a boolean that lets users enable or disable the computation of feature importance. Since interpretability is one of the most interesting aspects of this model, I believe this parameter should default to True.

What is the expected behavior?

When the compute_importance parameter is set to True, the fit method computes feature importance once training has finished. This is the same as the current behavior of the fit method.

When the compute_importance parameter is set to False, the fit method skips the feature importance computation entirely. This can significantly reduce the total time spent in fit, especially when feature importance is not immediately required.
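The two behaviors above can be illustrated with a small self-contained sketch. `ToyModel`, its univariate-slope "training" and its permutation-importance step are hypothetical stand-ins chosen only for illustration (TabNet computes its importances differently); the only part that mirrors the proposal is the `compute_importance` flag gating a post-training step.

```python
import random


def mse(model, X, y):
    """Mean squared error of model.predict(X) against y."""
    preds = model.predict(X)
    return sum((p - t) ** 2 for p, t in zip(preds, y)) / len(y)


class ToyModel:
    """Hypothetical stand-in for a TabNet model; NOT the pytorch_tabnet API."""

    def fit(self, X_train, y_train, compute_importance=True):
        # Toy "training": one least-squares slope per feature, fitted independently.
        my = sum(y_train) / len(y_train)
        self.coefs_ = []
        for j in range(len(X_train[0])):
            col = [row[j] for row in X_train]
            mx = sum(col) / len(col)
            ssx = sum((x - mx) ** 2 for x in col)
            sxy = sum((x - mx) * (t - my) for x, t in zip(col, y_train))
            self.coefs_.append(sxy / ssx if ssx else 0.0)
        if compute_importance:
            # Post-training step; skipped entirely when compute_importance=False.
            self.feature_importances_ = self._compute_feature_importances(X_train, y_train)
        return self

    def predict(self, X):
        return [sum(c * x for c, x in zip(self.coefs_, row)) for row in X]

    def _compute_feature_importances(self, X, y, seed=0):
        # Permutation importance (illustrative stand-in, not TabNet's method):
        # error increase after shuffling one column, normalised to sum to 1.
        rng = random.Random(seed)
        base = mse(self, X, y)
        raw = []
        for j in range(len(X[0])):
            col = [row[j] for row in X]
            rng.shuffle(col)
            X_perm = [row[:j] + [v] + row[j + 1:] for row, v in zip(X, col)]
            raw.append(max(mse(self, X_perm, y) - base, 0.0))
        total = sum(raw)
        return [r / total if total else 0.0 for r in raw]


rng = random.Random(42)
X = [[rng.random(), rng.random()] for _ in range(200)]
y = [3.0 * a + 0.1 * b for a, b in X]

full = ToyModel().fit(X, y)                            # default: importances computed
fast = ToyModel().fit(X, y, compute_importance=False)  # importance step skipped
print(full.feature_importances_)
print(hasattr(fast, "feature_importances_"))  # False
```

With the flag set to False the attribute is simply never assigned, so code that later reads `feature_importances_` may want to check for it first.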

What is the motivation or use case for adding/changing the behavior?

The motivation behind adding the compute_importance parameter is to give users more flexibility and control over how TabNet is trained. Currently, feature importance is always computed at the end of every call to fit, which is unnecessary and time-consuming in some scenarios.

Allowing users to disable this computation can save significant time when running a hyper-parameter search or training many models. This is particularly valuable when feature importance is only needed once the best model has been selected.
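The hyper-parameter search workflow described above might look like the following sketch. It assumes a hypothetical `ToyModel` with a single `shrink` hyper-parameter and a cheap stand-in for the importance step (not the pytorch_tabnet API); every candidate is fitted with `compute_importance=False`, and only the winning configuration is refitted with the default `True`, so the importance step runs exactly once.

```python
import random


class ToyModel:
    """Hypothetical stand-in for a TabNet model; NOT the pytorch_tabnet API."""

    importance_calls = 0  # how many times the (normally costly) importance step ran

    def __init__(self, shrink):
        self.shrink = shrink  # toy hyper-parameter: ridge-style shrinkage

    def fit(self, X, y, compute_importance=True):
        # Toy training: one shrunken least-squares slope per feature.
        my = sum(y) / len(y)
        self.coefs_ = []
        for j in range(len(X[0])):
            col = [row[j] for row in X]
            mx = sum(col) / len(col)
            ssx = sum((x - mx) ** 2 for x in col)
            sxy = sum((x - mx) * (t - my) for x, t in zip(col, y))
            self.coefs_.append(sxy / (ssx + self.shrink))
        if compute_importance:
            ToyModel.importance_calls += 1
            # Cheap stand-in for the real computation: normalised |coefficients|.
            total = sum(abs(c) for c in self.coefs_) or 1.0
            self.feature_importances_ = [abs(c) / total for c in self.coefs_]
        return self

    def score(self, X, y):
        # Negative mean squared error, so higher is better.
        preds = [sum(c * x for c, x in zip(self.coefs_, row)) for row in X]
        return -sum((p - t) ** 2 for p, t in zip(preds, y)) / len(y)


rng = random.Random(0)
X = [[rng.random(), rng.random()] for _ in range(300)]
y = [3.0 * a + 0.5 * b for a, b in X]
X_train, y_train = X[:200], y[:200]
X_valid, y_valid = X[200:], y[200:]

# Search phase: every candidate skips the importance computation.
scores = {}
for shrink in [0.0, 1.0, 10.0, 100.0]:
    model = ToyModel(shrink).fit(X_train, y_train, compute_importance=False)
    scores[shrink] = model.score(X_valid, y_valid)
best_shrink = max(scores, key=scores.get)

# Final phase: refit only the winner, with the default compute_importance=True.
best = ToyModel(best_shrink).fit(X_train, y_train)
print(best_shrink, best.feature_importances_)
print(ToyModel.importance_calls)  # 1
```

Refitting the winner costs one extra training run; this sketch accepts that cost because it relies only on the proposed fit flag and not on any separate way of triggering the importance computation.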

How should this be implemented in your opinion?

To implement this feature, the following steps can be taken:

  1. Modify the fit method of the TabModel class to include the new compute_importance parameter. This parameter should default to True to maintain the current behavior.
  2. In the fit method, check the value of the compute_importance parameter before performing the computation of feature importance. If the parameter is set to False, skip the feature importance computation entirely.
  3. Update the documentation and code examples to reflect the addition of the compute_importance parameter and its usage.

Here's the new fit method:

    def fit(
        self,
        X_train,
        y_train,
        eval_set=None,
        eval_name=None,
        eval_metric=None,
        loss_fn=None,
        weights=0,
        max_epochs=100,
        patience=10,
        batch_size=1024,
        virtual_batch_size=128,
        num_workers=0,
        drop_last=True,
        callbacks=None,
        pin_memory=True,
        from_unsupervised=None,
        warm_start=False,
        augmentations=None,
        compute_importance=True
    ):
        """Train a neural network stored in self.network
        Using train_dataloader for training data and
        valid_dataloader for validation.

        Parameters
        ----------
        X_train : np.ndarray
            Train set
        y_train : np.array
            Train targets
        eval_set : list of tuple
            List of eval tuple set (X, y).
            The last one is used for early stopping
        eval_name : list of str
            List of eval set names.
        eval_metric : list of str
            List of evaluation metrics.
            The last metric is used for early stopping.
        loss_fn : callable or None
            a PyTorch loss function
        weights : int or dict
            0 for no balancing
            1 for automated balancing
            dict for custom weights per class
        max_epochs : int
            Maximum number of epochs during training
        patience : int
            Number of consecutive non-improving epochs before early stopping
        batch_size : int
            Training batch size
        virtual_batch_size : int
            Batch size for Ghost Batch Normalization (virtual_batch_size < batch_size)
        num_workers : int
            Number of workers used in torch.utils.data.DataLoader
        drop_last : bool
            Whether to drop last batch during training
        callbacks : list of callback function
            List of custom callbacks
        pin_memory: bool
            Whether to set pin_memory to True or False during training
        from_unsupervised: unsupervised trained model
            Use a previously self supervised model as starting weights
        warm_start: bool
            If True, current model parameters are used to start training
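        augmentations : callable or None
            Data augmentation function applied during training
        compute_importance : bool
            Whether to compute feature importance once training is finished
            (default True). If False, this step is skipped entirely.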
        """

        self.max_epochs = max_epochs
        self.patience = patience
        self.batch_size = batch_size
        self.virtual_batch_size = virtual_batch_size
        self.num_workers = num_workers
        self.drop_last = drop_last
        self.input_dim = X_train.shape[1]
        self._stop_training = False
        self.pin_memory = pin_memory and (self.device.type != "cpu")
        self.augmentations = augmentations
        self.compute_importance = compute_importance

        if self.augmentations is not None:
            # This ensures reproducibility
            self.augmentations._set_seed()

        eval_set = eval_set if eval_set else []

        if loss_fn is None:
            self.loss_fn = self._default_loss
        else:
            self.loss_fn = loss_fn

        check_input(X_train)
        check_warm_start(warm_start, from_unsupervised)

        self.update_fit_params(
            X_train,
            y_train,
            eval_set,
            weights,
        )

        # Validate and reformat eval set depending on training data
        eval_names, eval_set = validate_eval_set(eval_set, eval_name, X_train, y_train)

        train_dataloader, valid_dataloaders = self._construct_loaders(
            X_train, y_train, eval_set
        )

        if from_unsupervised is not None:
            # Update parameters to match self pretraining
            self.__update__(**from_unsupervised.get_params())

        if not hasattr(self, "network") or not warm_start:
            # model has never been fitted before or warm_start is False
            self._set_network()
        self._update_network_params()
        self._set_metrics(eval_metric, eval_names)
        self._set_optimizer()
        self._set_callbacks(callbacks)

        if from_unsupervised is not None:
            self.load_weights_from_unsupervised(from_unsupervised)
            warnings.warn("Loading weights from unsupervised pretraining")
        # Call method on_train_begin for all callbacks
        self._callback_container.on_train_begin()

        # Training loop over epochs
        for epoch_idx in range(self.max_epochs):

            # Call method on_epoch_begin for all callbacks
            self._callback_container.on_epoch_begin(epoch_idx)

            self._train_epoch(train_dataloader)

            # Apply predict epoch to all eval sets
            for eval_name, valid_dataloader in zip(eval_names, valid_dataloaders):
                self._predict_epoch(eval_name, valid_dataloader)

            # Call method on_epoch_end for all callbacks
            self._callback_container.on_epoch_end(
                epoch_idx, logs=self.history.epoch_metrics
            )

            if self._stop_training:
                break

        # Call method on_train_end for all callbacks
        self._callback_container.on_train_end()
        self.network.eval()

        if self.compute_importance:
            # compute feature importance once the best model is defined
            self.feature_importances_ = self._compute_feature_importances(X_train)

Are you willing to work on this yourself? Yes, I am willing to work on implementing this feature.

Optimox commented 1 year ago

@CesarLeblanc Yes, certainly. Please feel free to open a PR.

CesarLeblanc commented 1 year ago

@Optimox Is it fine if I open a PR for both this Issue (#493) and the one I made two days ago about working with sparse data (#492)?

Optimox commented 1 year ago

@CesarLeblanc please open two separate PRs otherwise things are going to get messy

CesarLeblanc commented 1 year ago

@Optimox you are right. I'm fixing this issue first (see PR #494), and once it is merged I will open a new PR to allow working with sparse data.

Optimox commented 1 year ago

closed by #494