dreamquark-ai / tabnet

PyTorch implementation of the TabNet paper: https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License

Fitting TabNetClassifier on sparse data #492

Closed CesarLeblanc closed 11 months ago

CesarLeblanc commented 1 year ago

Feature request

For the moment, if we try to pass sparse matrices to the fit method of TabNetClassifier (whether as X_train or inside eval_set), an error is raised from _ensure_sparse_format in sklearn/utils/validation.py: TypeError: A sparse matrix was passed, but dense data is required. Use X.toarray() to convert to a dense numpy array. This happens because the check_input function in pytorch_tabnet/utils.py calls sklearn's check_array without the argument accept_sparse=True. However, only a small piece of code needs to change, along with adding two new Dataset classes such as:

import torch
from torch.utils.data import Dataset


class SparsePredictDataset(Dataset):
    """
    Custom dataset class for predictions with sparse input data

    Parameters
    ----------
    X : scipy.sparse.csr_matrix
        The sparse input matrix
    """

    def __init__(self, X):
        self.X = X

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, index):
        # Densify a single row: X[index] is a 1 x n_features sparse matrix,
        # so toarray()[0] yields the dense 1-D feature vector
        X = torch.from_numpy(self.X[index].toarray()[0]).float()
        return X
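As an aside on the indexing above (a standalone check, assuming SciPy and PyTorch are available): slicing a single row of a csr_matrix returns a 1 x n_features sparse matrix, which is why toarray()[0] is needed to recover the 1-D feature vector:

import numpy as np
import torch
from scipy.sparse import csr_matrix

X = csr_matrix(np.eye(3, dtype=np.float32))
print(X[1].toarray().shape)                 # (1, 3): still two-dimensional
print(torch.from_numpy(X[1].toarray()[0]))  # tensor([0., 1., 0.])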

class SparseTorchDataset(Dataset):
    """
    Custom dataset class for sparse input data

    Parameters
    ----------
    X : scipy.sparse.csr_matrix
        The sparse input matrix
    y : numpy.ndarray
        The target data
    """

    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, index):
        X = torch.from_numpy(self.X[index].toarray()[0]).float()
        y = self.y[index]
        return X, y
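As a usage sketch (assuming the classes above are in scope), both datasets drop straight into a standard DataLoader, densifying one batch at a time so the full matrix is never materialized:

import numpy as np
from scipy.sparse import random as sparse_random
from torch.utils.data import DataLoader

# Toy data: 100 samples, 50 features, ~5% density (illustrative values)
X = sparse_random(100, 50, density=0.05, format="csr", dtype=np.float32)
y = np.random.randint(0, 2, size=100)

loader = DataLoader(SparseTorchDataset(X, y), batch_size=32, shuffle=True)
batch_X, batch_y = next(iter(loader))
print(batch_X.shape, batch_y.shape)  # torch.Size([32, 50]) torch.Size([32])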

Then, we should modify the function that checks the input data for compatibility with the model as follows:

import pandas as pd
import sklearn.utils


def check_input(X):
    """
    Check the input data for compatibility with the model.

    Parameters
    ----------
    X : array-like or sparse matrix, shape (n_samples, n_features)
        The input data.

    Raises
    ------
    TypeError
        If `X` is a Pandas DataFrame or Series.
    """
    if isinstance(X, (pd.DataFrame, pd.Series)):  # DataFrames/Series are not supported
        err_message = "Pandas DataFrame is not supported: apply X.values when calling fit"
        raise TypeError(err_message)
    # accept_sparse=True is the key change: sparse input now passes validation
    sklearn.utils.check_array(X, accept_sparse=True)
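As a quick sanity check (a minimal sketch, assuming SciPy is installed), the updated function now lets sparse input through while still rejecting DataFrames:

import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

check_input(csr_matrix(np.eye(5)))      # passes: sparse input is now accepted
check_input(np.eye(5))                  # passes: dense arrays work as before
# check_input(pd.DataFrame(np.eye(5)))  # would still raise TypeError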

Once that's done, the only remaining change is to check whether the input data is sparse (for example, using scipy.sparse.issparse) and then construct the DataLoader with the regular Dataset or the sparse Dataset accordingly.

What is the expected behavior?

The expected behavior is that the fit method of TabNetClassifier should support sparse matrices for both X_train and eval_set inputs without raising a TypeError related to dense data requirement.

What is motivation or use case for adding/changing the behavior?

Sparse matrices are commonly used in various machine learning tasks, especially when dealing with high-dimensional data. By adding support for sparse matrices in the fit method of TabNetClassifier, users will have more flexibility and efficiency when working with sparse input data.

How should this be implemented in your opinion?

The proposed implementation involves modifying the check_input function in pytorch_tabnet/utils.py to accept sparse matrices by adding the accept_sparse=True argument. Additionally, two new custom Dataset classes, SparsePredictDataset and SparseTorchDataset, should be implemented to handle sparse input data. These classes densify one row at a time with toarray() and build PyTorch tensors from the resulting arrays, so the full matrix never has to be converted to dense at once.

To complete the implementation, we need to modify the code that constructs the DataLoader based on the input data. We should check if the input data is sparse using scipy.sparse.issparse and use the appropriate Dataset class accordingly.

Here's the suggested implementation; the same dispatch should be added to each method that constructs a DataLoader (i.e., predict, explain, predict_proba, and create_dataloaders):

from scipy.sparse import issparse  # new import needed at the top of pytorch_tabnet/utils.py


def create_dataloaders(
    X_train, y_train, eval_set, weights, batch_size, num_workers, drop_last, pin_memory
):
    """
    Create dataloaders with or without subsampling depending on weights and balanced.

    Parameters
    ----------
    X_train : np.ndarray or scipy.sparse.csr_matrix
        Training data
    y_train : np.ndarray
        Mapped Training targets
    eval_set : list of tuple
        List of eval tuple set (X, y)
    weights : either 0, 1, dict or iterable
        if 0 (default) : no weights will be applied
        if 1 : classification only, will balanced class with inverse frequency
        if dict : keys are corresponding class values are sample weights
        if iterable : list or np array must be of length equal to nb elements
                      in the training set
    batch_size : int
        how many samples per batch to load
    num_workers : int
        how many subprocesses to use for data loading. 0 means that the data
        will be loaded in the main process
    drop_last : bool
        set to True to drop the last incomplete batch, if the dataset size is not
        divisible by the batch size. If False and the size of dataset is not
        divisible by the batch size, then the last batch will be smaller
    pin_memory : bool
        Whether to pin GPU memory during training

    Returns
    -------
    train_dataloader, valid_dataloader : torch.DataLoader, torch.DataLoader
        Training and validation dataloaders
    """
    need_shuffle, sampler = create_sampler(weights, y_train)

    if not issparse(X_train):
        train_dataloader = DataLoader(
            TorchDataset(X_train.astype(np.float32), y_train),
            batch_size=batch_size,
            sampler=sampler,
            shuffle=need_shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
        )
    else:
        train_dataloader = DataLoader(
            SparseTorchDataset(X_train.astype(np.float32), y_train),
            batch_size=batch_size,
            sampler=sampler,
            shuffle=need_shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
        )

    valid_dataloaders = []
    for X, y in eval_set:
        if not issparse(X):
            valid_dataloaders.append(
                DataLoader(
                    TorchDataset(X.astype(np.float32), y),
                    batch_size=batch_size,
                    shuffle=False,
                    num_workers=num_workers,
                    pin_memory=pin_memory,
                )
            )
        else:
            valid_dataloaders.append(
                DataLoader(
                    SparseTorchDataset(X.astype(np.float32), y),
                    batch_size=batch_size,
                    shuffle=False,
                    num_workers=num_workers,
                    pin_memory=pin_memory,
                )
            )

    return train_dataloader, valid_dataloaders
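For the inference-side methods (predict, explain, predict_proba), the dispatch is the same. As a sketch (assuming SparsePredictDataset from above is in scope, and noting that create_inference_dataloader is a hypothetical helper rather than an existing pytorch_tabnet API), the three methods could share one factory:

from scipy.sparse import issparse
from torch.utils.data import DataLoader

from pytorch_tabnet.utils import PredictDataset  # existing dense dataset


def create_inference_dataloader(X, batch_size, num_workers=0, pin_memory=False):
    """Hypothetical helper: pick the dense or sparse Dataset for inference."""
    dataset = SparsePredictDataset(X) if issparse(X) else PredictDataset(X)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

predict, explain, and predict_proba would then call this helper instead of building their DataLoader inline, keeping the sparse/dense logic in one place.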

By making these changes, the TabNetClassifier will support both dense and sparse input data, allowing users to work with sparse matrices seamlessly.

Are you willing to work on this yourself? Yes, I am willing to work on implementing these changes.