kathrinse / TabSurvey

Experiments on Tabular Data Models
MIT License

AssertionError: you must pass in 0 values for your categories input #9

Closed: sonnguyen129 closed this issue 1 year ago

sonnguyen129 commented 1 year ago

Hi, @kathrinse. When I tried to train TabTransformer, I got the following error. It seems that args.cat_dims is empty, even though I passed cat_dims in my config:

[]   # args.cat_dims
On Device: cuda
Using dim 128 and batch size 10240
On Device: cuda
[W 2022-10-24 22:25:20,624] Trial 10 failed because of the following error: AssertionError('you must pass in 0 values for your categories input')
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "train.py", line 117, in __call__
    sc, time = cross_validation(model, self.X, self.y, self.args)
  File "train.py", line 60, in cross_validation
    loss_history, val_loss_history = curr_model.fit(X_train, y_train, X_test, y_test)  # X_val, y_val)
  File "/content/drive/MyDrive/Code/TabSurvey-main/models/tabtransformer.py", line 103, in fit
    out = self.model(x_categ, x_cont)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/drive/MyDrive/Code/TabSurvey-main/models/tabtransformer.py", line 463, in forward
    assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} ' \
AssertionError: you must pass in 0 values for your categories input
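The assertion compares the number of categorical columns in the batch (`x_categ.shape[-1]`) with `self.num_categories`. The message reports 0, which matches the printed `args.cat_dims` of `[]`. That suggests, as an assumption, that `num_categories` is the length of the `cat_dims` list, while `x_categ` still carries the categorical columns from the data (nine, if it follows the `cat_idxs` list in the config). The pure-Python sketch below reproduces that check without PyTorch; `check_categorical_input` is a hypothetical name, not a TabSurvey function.

```python
# Hypothetical helper (not part of TabSurvey): mirrors the shape check that
# fails in models/tabtransformer.py, assuming num_categories == len(cat_dims).


def check_categorical_input(cat_dims, x_categ_width):
    """Compare the number of categorical columns in a batch with the number
    of category cardinalities the model was built with."""
    num_categories = len(cat_dims)  # 0 when cat_dims is []
    assert x_categ_width == num_categories, (
        f"you must pass in {num_categories} values for your categories input"
    )
    return num_categories


if __name__ == "__main__":
    # Cardinalities from the config: the check passes for 9 categorical columns.
    print(check_categorical_input([10, 2, 7, 4, 6, 4, 4, 4, 4], 9))  # 9
    # Empty cat_dims, as printed in the log: the check fails.
    try:
        check_categorical_input([], 9)
    except AssertionError as err:
        print(err)  # you must pass in 0 values for your categories input
```

With an empty `cat_dims` and nine categorical columns, the sketch raises the same message as the traceback; with the nine cardinalities from the config, it passes.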

Here is my config

# General parameters
dataset: A
model_name: TabTransformer # LinearModel, KNN, SVM, DecisionTree, RandomForest, XGBoost, CatBoost, LightGBM, ModelTree
                # MLP, TabNet, VIME, TabTransformer, RLN, DNFNet, STG, NAM, DeepFM, SAINT
objective: classification # Don't change
# optimize_hyperparameters: True

# Preprocessing parameters
scale: True
target_encode: True
one_hot_encode: False

# Training parameters
batch_size: 10240
val_batch_size: 2560
early_stopping_rounds: 100
epochs: 2
logging_period: 100

# About the data
num_classes: 4  # for classification
num_features: 18
cat_idxs: [0, 1, 2, 3, 4, 14, 15, 16, 17]
# cat_dims: will be automatically set.
cat_dims: [10, 2, 7, 4, 6, 4, 4, 4, 4]
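One place worth checking is how `cat_dims` gets populated. The config comment says it "will be automatically set", and the log printed `[]` even though the config lists nine values, so the loader presumably recomputes it from the categorical indices (an assumption; the thread does not say what the eventual fix was). If the loader reads the indices under a different key than the one used here, it would find no categorical columns and leave `cat_dims` empty, so the key spelling (for example `cat_idx` versus `cat_idxs`) is worth confirming against the repository's own example configs and argument parser. The sketch below uses hypothetical helper names: `derive_cat_dims` counts distinct values per categorical column, and `validate_categorical_config` flags a missing or mismatched index list.

```python
# Hypothetical helpers (not part of TabSurvey). The assumption is that the
# loader recomputes cat_dims from the categorical column indices, which is
# what the config comment "cat_dims: will be automatically set." suggests.


def derive_cat_dims(rows, cat_idx):
    """Number of distinct values in each categorical column, in cat_idx order."""
    return [len({row[i] for row in rows}) for i in cat_idx]


def validate_categorical_config(cat_idx, cat_dims):
    """Return a list of problems; an empty list means the two settings agree."""
    cat_idx = list(cat_idx or [])    # treat a missing key (None) as empty
    cat_dims = list(cat_dims or [])
    problems = []
    if not cat_idx:
        problems.append("no categorical indices found; cat_dims would stay empty")
    if len(cat_idx) != len(cat_dims):
        problems.append(
            f"{len(cat_idx)} categorical indices but {len(cat_dims)} cardinalities"
        )
    if any(d < 1 for d in cat_dims):
        problems.append("every cardinality must be at least 1")
    return problems


if __name__ == "__main__":
    # Toy rows: columns 0 and 1 are categorical, column 2 is numeric.
    rows = [[0, "a", 1.5], [1, "b", 2.0], [0, "a", 3.1]]
    print(derive_cat_dims(rows, [0, 1]))  # [2, 2]

    # Values from the config in this issue: consistent.
    print(validate_categorical_config(
        [0, 1, 2, 3, 4, 14, 15, 16, 17], [10, 2, 7, 4, 6, 4, 4, 4, 4]))  # []

    # What the loader would see if it never read the index key.
    print(validate_categorical_config(None, []))
```

Running a check like this on the parsed arguments before the model is built would surface an empty index list at startup, instead of as an AssertionError during the first forward pass.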

Hope to hear from you soon. Thanks in advance.

sonnguyen129 commented 1 year ago

I think I have solved this error. I will reopen this issue if I have more questions.