dreamquark-ai / tabnet

PyTorch implementation of TabNet paper: https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License

Fitting TabNetClassifier on sparse data #495

Closed CesarLeblanc closed 11 months ago

CesarLeblanc commented 1 year ago

Closes #492

IMPORTANT: Please do not create a Pull Request without creating an issue first.

Any change needs to be discussed before proceeding. Failure to do so may result in the rejection of the pull request.

What kind of change does this PR introduce?

A new feature: allowing the input data (both in train and validation/test) to be sparse.

Does this PR introduce a breaking change?

No; users simply gain more flexibility and efficiency when working with sparse input data.

What needs to be documented once your changes are merged?

Only the README.md, and it has already been updated to mention that X_train does not have to be a np.array but can also be a scipy.sparse.csr_matrix.
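For illustration, here is a minimal sketch of what the feature enables once the PR is merged. The synthetic data and the fit parameters are only placeholders; the point is simply that the matrix passed to fit can be a scipy.sparse.csr_matrix rather than a dense np.array.

```python
import numpy as np
from scipy.sparse import csr_matrix
from pytorch_tabnet.tab_model import TabNetClassifier

# Build a mostly-zero feature matrix (synthetic, for illustration only).
rng = np.random.default_rng(0)
X_dense = rng.random((1000, 50))
X_dense[X_dense < 0.9] = 0.0           # roughly 90% sparsity
X_train = csr_matrix(X_dense)          # sparse input instead of a dense np.array
y_train = rng.integers(0, 2, size=1000)

clf = TabNetClassifier()
clf.fit(X_train, y_train, max_epochs=10)  # fit now accepts the sparse matrix directly
preds = clf.predict(X_train)
```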

Closing issues

Put closes #XXXX in your comment to auto-close the issue that your PR fixes (if such).

Optimox commented 12 months ago

Also it would be perfect if you could squash the commits into one at the end!

CesarLeblanc commented 12 months ago

Oh sorry, it looks like the squashing failed. I received an unexpected error (I'm not too familiar with Git, to be honest). Would you prefer that I look into it and try to squash the 4 commits into one, or is it OK with you this way?

Optimox commented 11 months ago

I have a question: you always refer to csr_matrix, but would everything work with csc_matrix? Is there anything specific to CSR here?

CesarLeblanc commented 11 months ago

@Optimox most of the time people will have far more rows (samples) than columns (features) in their data, so the CSR (Compressed Sparse Row) format will generally be more efficient: it stores the sparse matrix row-wise, which allows efficient access to rows and row-based operations. Nothing prevents users from using a csc_matrix (the PR I made allows both), but I've never seen CSC matrices used in deep learning (which obviously doesn't mean it doesn't exist).
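As a small illustration of the difference between the two formats (standard scipy operations, nothing specific to this PR), converting a csc_matrix to CSR is cheap if a row-oriented layout is preferred for sample-wise access:

```python
import numpy as np
from scipy.sparse import csr_matrix, csc_matrix

dense = np.array([[0., 1., 0.],
                  [2., 0., 0.],
                  [0., 0., 3.]])

X_csr = csr_matrix(dense)   # row-oriented: efficient row slicing, suits per-sample batching
X_csc = csc_matrix(dense)   # column-oriented: efficient column slicing

# Converting between the two formats is lossless:
X_from_csc = X_csc.tocsr()
assert (X_csr != X_from_csc).nnz == 0   # identical contents
```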