What is the current behavior?
The validation dataloaders in `pretraining_utils.py` reuse the sampler generated from `X_train`. This causes an error when `X_train` and `X_valid` in `eval_set` do not have the same number of rows.
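To make the mechanism concrete, here is a stdlib-only sketch. It imitates a weighted sampler with `random.Random.choices`, standing in for torch's `WeightedRandomSampler`, which draws indices from `range(len(weights))`. The sizes mirror the reproduction script, and `weighted_indices` is a hypothetical helper, not part of pytorch_tabnet.

```python
import random


def weighted_indices(weights, num_samples, seed=0):
    """Hypothetical stand-in for a weighted sampler drawing with replacement:
    returns num_samples indices from range(len(weights)), each picked with
    probability proportional to its weight."""
    rng = random.Random(seed)
    return rng.choices(range(len(weights)), weights=weights, k=num_samples)


num_train, num_valid = 100_000, 50_000
train_weights = [1.0] * num_train      # one weight per training row
valid_rows = list(range(num_valid))    # placeholder for the validation data

# The sampler is sized from the *training* weights...
indices = weighted_indices(train_weights, num_train)

# ...so reusing it on the smaller validation set yields out-of-range indices.
out_of_range = [i for i in indices if i >= num_valid]
print(f"{len(out_of_range)} of {len(indices)} indices exceed the validation size")

try:
    _ = [valid_rows[i] for i in indices]
except IndexError as exc:
    print("IndexError:", exc)
```

The weight values themselves do not matter here. The failure comes from the index range, which is tied to the length of the training weights.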
If the current behavior is a bug, please provide the steps to reproduce.
Here's a script that reproduces the issue, with `weights` passed as an ndarray:
```python
from pytorch_tabnet.pretraining import TabNetPretrainer
import numpy as np
import torch

# Set the random seed for reproducibility
np.random.seed(42)

# Generate random features
num_train_samples = 100000
num_valid_samples = 50000
num_features = 10
X_train = np.random.rand(num_train_samples, num_features)
X_valid = np.random.rand(num_valid_samples, num_features)

# Generate random binary labels
y_train = np.random.randint(2, size=num_train_samples)
y_valid = np.random.randint(2, size=num_valid_samples)

# Per-sample weights, one per training row (inverse class frequency)
num_positive_samples = np.sum(y_train)
num_negative_samples = len(y_train) - num_positive_samples
class_weights = np.zeros(len(y_train))
class_weights[y_train == 0] = 1 / num_negative_samples
class_weights[y_train == 1] = 1 / num_positive_samples

# TabNetPretrainer
unsupervised_model = TabNetPretrainer(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    mask_type='entmax',  # or "sparsemax"
    device_name='cpu',
)
unsupervised_model.fit(
    X_train=X_train,
    eval_set=[X_valid],
    pretraining_ratio=0.5,
    weights=class_weights,
)
```
Expected behavior
The weights should only affect how `X_train` is sampled, because `X_valid` does not need them. Currently, the script above fails with `IndexError: index 94028 is out of bounds for axis 0 with size 50000`. This suggests that the weights are also applied to `X_valid`.
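The following is a minimal sketch of the expected behavior. It uses plain `torch.utils.data` classes rather than pytorch_tabnet's own datasets. `build_loaders` is a hypothetical helper, not the library's API. It attaches the weighted sampler only to the training loader and reads each validation set sequentially with `shuffle=False`. If the same sampler were reused on the 500-row validation set, iterating that loader would raise an `IndexError` as soon as it drew an index of 500 or more.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler


def build_loaders(X_train, eval_set, weights=None, batch_size=1024):
    """Hypothetical helper (not pytorch_tabnet's API).

    The weighted sampler is built from, and attached to, the training data only.
    Validation loaders visit every row exactly once, in order.
    """
    train_ds = TensorDataset(X_train)
    if weights is None:
        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    else:
        # One weight per training row; indices are drawn from range(len(weights)).
        sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
        train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler)

    # No sampler here: each validation set may have any number of rows.
    valid_loaders = [
        DataLoader(TensorDataset(X), batch_size=batch_size, shuffle=False)
        for X in eval_set
    ]
    return train_loader, valid_loaders


torch.manual_seed(42)
X_train = torch.rand(1000, 10)
X_valid = torch.rand(500, 10)  # smaller than X_train, as in the report
weights = torch.rand(1000)     # one weight per training row

train_loader, valid_loaders = build_loaders(X_train, [X_valid], weights, batch_size=128)
n_train_rows = sum(batch[0].shape[0] for batch in train_loader)
n_valid_rows = sum(batch[0].shape[0] for batch in valid_loaders[0])
print(n_train_rows, n_valid_rows)  # 1000 500
```

Evaluation does not need weighted sampling. Iterating the validation data in order also keeps the validation metrics deterministic across epochs.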
Other relevant information:
python version: 3.8
Operating System: linux, macos