I have the same issue when trying to use load_data() on separate train and test dataframes. Oddly, one of the two `tabular_torch_dataset.TorchTextDataset` objects returned by load_data will train; the other will not.
To make it work, I have to load the data from file and use exactly the same setup as in the Colab notebook.
@codeKgu Seems like you put a ton of work into this repo. Would be great to get this fixed.
@petulla and @mikiwz - I think I found the reason for this issue. In this line: https://github.com/georgian-io/Multimodal-Toolkit/blob/master/multimodal_transformers/data/load_data.py#L228
The package concatenates the train, val, and test dfs. Then, if you're processing the categorical features with one-hot encoding (the default), the encoding is computed over ALL of those dfs together.
For example, say your train df has a categorical feature with values ["a", "b"]. On its own, this would be one-hot encoded as 2 separate columns (a and b). Now say your test data has values ["a", "c"]. Because the package concatenates train and test, one-hot encoding produces 3 columns (a, b, and c). But if you load your test dataset separately, only "a" and "c" are encoded, giving 2 columns instead of 3. That is the issue: the model was trained on 3 columns, but you're giving it 2 columns to predict with.
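To make the column-count mismatch concrete, here is a minimal pure-Python sketch of the example above. It is not the library's encoder; `fit_categories` and `one_hot` are hypothetical helpers that mimic "learn the set of categories, then emit one 0/1 column per category".

```python
def fit_categories(values):
    """Return the sorted distinct categories, i.e. the columns the encoder 'learns'."""
    return sorted(set(values))

def one_hot(values, categories):
    """Encode each value as a 0/1 row over the given category order."""
    return [[1 if v == c else 0 for c in categories] for v in values]

train = ["a", "b", "a"]
test = ["a", "c"]

# Fitting on train + test concatenated (as described above) gives 3 columns.
joint_cats = fit_categories(train + test)
train_joint = one_hot(train, joint_cats)

# Fitting on the test split alone gives only 2 columns.
test_cats = fit_categories(test)
test_alone = one_hot(test, test_cats)

print(joint_cats, len(train_joint[0]))  # ['a', 'b', 'c'] 3
print(test_cats, len(test_alone[0]))    # ['a', 'c'] 2

# Reusing the categories fitted earlier keeps the width the model expects.
test_reused = one_hot(test, joint_cats)
print(test_reused)  # [[1, 0, 0], [0, 0, 1]]
```

The general point is that an encoder should be fitted once and reused for every split; fitting it per split is what produces the 3-vs-2 column mismatch.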
The way around this is to either not use categorical data, or use label encoding instead:
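```python
from multimodal_transformers.data.load_data import load_data

# Assumes test_data, tokenizer and data_args are set up as in the Colab notebook.
test_dataset_2 = load_data(
    test_data,
    data_args.column_info['text_cols'],
    tokenizer,
    label_col=data_args.column_info['label_col'],
    label_list=data_args.column_info['label_list'],
    numerical_cols=data_args.column_info['num_cols'],
    sep_text_token_str=tokenizer.sep_token,
    # Use label encoding instead of the default one-hot encoding
    categorical_encode_type="label",
)
```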
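Label encoding avoids the mismatch because each categorical feature stays a single integer column, whatever categories a split happens to contain. The sketch below is a hypothetical pure-Python illustration, not the package's implementation; `fit_label_map`, `label_encode`, and the `unknown=-1` fallback are assumptions made for the example. It also shows that the integer codes only line up across splits when the mapping fitted on the training data is reused.

```python
def fit_label_map(values):
    """Map each distinct category to an integer code, in sorted order."""
    return {cat: code for code, cat in enumerate(sorted(set(values)))}

def label_encode(values, mapping, unknown=-1):
    """Encode values with a fixed mapping; unseen categories get `unknown` (hypothetical choice)."""
    return [mapping.get(v, unknown) for v in values]

train = ["a", "b", "a"]
test = ["a", "c"]

train_map = fit_label_map(train)       # {'a': 0, 'b': 1}
print(label_encode(train, train_map))  # [0, 1, 0]

# Still one column per feature when reusing the train mapping on test.
print(label_encode(test, train_map))   # [0, -1]

# A mapping fitted on test alone keeps the width but changes what codes mean:
test_map = fit_label_map(test)         # {'a': 0, 'c': 1}
print(label_encode(test, test_map))    # [0, 1], so 'c' gets the code 'b' had in train
```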
Closing as this has been answered.
I was trying to reproduce the example notebook https://colab.research.google.com/github/georgianpartners/Multimodal-Toolkit/blob/master/notebooks/text_w_tabular_classification.ipynb#scrollTo=ABT1hK9cRsuk and got the following error:
RuntimeError Traceback (most recent call last)