dreamquark-ai / tabnet

PyTorch implementation of the TabNet paper: https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License

sampler is not needed in pretrain mode for valid dataloader #499

Open cyang31 opened 1 year ago

Describe the bug

What is the current behavior? The validation dataloader created in pretraining_utils.py reuses the sampler generated for X_train, which causes errors whenever X_train and the X_valid passed in eval_set differ in size.
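The failure mode can be simulated without TabNet or PyTorch. The stdlib-only sketch below is an illustration under assumptions, not the library's code: `weighted_sampler` is a hypothetical stand-in for a weighted sampler that draws with replacement (the default for PyTorch's WeightedRandomSampler), and `load` is a hypothetical stand-in for a dataloader. It shows that indices drawn for X_train can exceed the length of a smaller X_valid, while sequential iteration over X_valid cannot.

```python
import random


def weighted_sampler(weights, seed=0):
    """Hypothetical stand-in for a weighted sampler (with replacement):
    draws len(weights) row indices, each in range(len(weights))."""
    rng = random.Random(seed)
    return rng.choices(range(len(weights)), weights=weights, k=len(weights))


def load(dataset, indices):
    """Hypothetical stand-in for a dataloader: fetch rows in sampler order."""
    return [dataset[i] for i in indices]


n_train, n_valid = 10, 4
X_train = [[i] for i in range(n_train)]
X_valid = [[i] for i in range(n_valid)]

# One weight per *training* row. Toy values that put all the mass on rows
# 4..9, which do not exist in X_valid, so the failure is deterministic here.
train_weights = [0.0] * n_valid + [1.0] * (n_train - n_valid)

train_idx = weighted_sampler(train_weights)
train_batch = load(X_train, train_idx)  # fine: every index < n_train

try:
    load(X_valid, train_idx)  # reusing the train sampler on X_valid
    shared_sampler_failed = False
except IndexError:
    shared_sampler_failed = True

# Sequential, unweighted iteration over the validation set always works.
valid_batch = load(X_valid, range(n_valid))
```

With realistic weights (every row nonzero, as in the script below) the out-of-range index is not guaranteed on every draw, but with 100000 draws over 100000 indices and only 50000 validation rows it is practically certain.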

If the current behavior is a bug, please provide the steps to reproduce. Here's a script that reproduces the issue, where weights is passed as a NumPy ndarray.

from pytorch_tabnet.pretraining import TabNetPretrainer
import numpy as np
import torch

# Set the random seed for reproducibility
np.random.seed(42)

# Generate random features
num_train_samples = 100000
num_valid_samples = 50000
num_features = 10

X_train = np.random.rand(num_train_samples, num_features)
X_valid = np.random.rand(num_valid_samples, num_features)

# Generate random binary labels
y_train = np.random.randint(2, size=num_train_samples)
y_valid = np.random.randint(2, size=num_valid_samples)

# Per-sample weights: each sample gets the inverse frequency of its class
num_positive_samples = np.sum(y_train)
num_negative_samples = len(y_train) - num_positive_samples
class_weights = np.zeros(len(y_train))

class_weights[y_train == 0] = 1 / num_negative_samples
class_weights[y_train == 1] = 1 / num_positive_samples

# TabNetPretrainer
unsupervised_model = TabNetPretrainer(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    mask_type='entmax',  # or "sparsemax"
    device_name='cpu'
)

unsupervised_model.fit(
    X_train=X_train,
    eval_set=[X_valid],
    pretraining_ratio=0.5,
    weights=class_weights
)

Expected behavior

The weights should only affect how X_train is sampled. Instead, the script above currently fails with IndexError: index 94028 is out of bounds for axis 0 with size 50000, which suggests that the weights (through the sampler built for X_train) are also applied to X_valid, which is not necessary.
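One possible shape for a fix, sketched under assumptions: attach the weighted sampler to the training loader only, and build each validation loader with shuffle=False and no sampler. `make_dataloaders` is a hypothetical helper, not pytorch_tabnet's code, and a plain TensorDataset stands in for the library's own dataset classes; the unweighted branch using shuffle=True is also an assumption.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler


def make_dataloaders(X_train, eval_set, weights=None, batch_size=1024):
    """Hypothetical helper: the weighted sampler is used for the training
    loader only; validation loaders iterate over their rows in order."""
    train_ds = TensorDataset(torch.as_tensor(X_train, dtype=torch.float32))
    if weights is not None:
        # One weight per training row, drawn with replacement.
        sampler = WeightedRandomSampler(
            torch.as_tensor(weights, dtype=torch.double),
            num_samples=len(weights),
            replacement=True,
        )
        train_dl = DataLoader(train_ds, batch_size=batch_size, sampler=sampler)
    else:
        train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    # No sampler and no shuffling: each validation row is seen exactly once,
    # so the size of X_valid never has to match the size of X_train.
    valid_dls = [
        DataLoader(
            TensorDataset(torch.as_tensor(X, dtype=torch.float32)),
            batch_size=batch_size,
            shuffle=False,
        )
        for X in eval_set
    ]
    return train_dl, valid_dls


# Toy data with mismatched sizes, like the report but smaller.
X_tr = np.random.rand(1000, 10)
X_va = np.random.rand(500, 10)
w = np.ones(len(X_tr))

train_dl, valid_dls = make_dataloaders(X_tr, [X_va], weights=w, batch_size=256)
n_valid_seen = sum(batch[0].shape[0] for batch in valid_dls[0])
n_train_seen = sum(batch[0].shape[0] for batch in train_dl)
```

The design choice here is that weights exist to rebalance training batches; evaluation is meant to cover every validation row once, which sequential iteration guarantees regardless of the train/valid size ratio.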

Other relevant information:
python version: 3.8
Operating System: Linux, macOS