strakehyr opened this issue 2 years ago
Hello. I am not a maintainer of the repository, but my PRs #365 and #362 are currently awaiting approval, so I am getting familiar with the codebase. From https://github.com/timeseriesAI/tsai/blob/main/tsai/models/utils.py#L177 and https://github.com/timeseriesAI/tsai/blob/main/tsai/tslearner.py#L50 we can see that build_tabular_model and TSLearners are customized only for TabularModel, not for TabTransformer. I am planning to open a new pull request to improve support for TabTransformer (and GatedTabTransformer) once my current PRs are finished. In the meantime, you can support them with a thumbs up :)
Hi @strakehyr and @radi-cho, that's a good point. It would be good to have TabTransformer supported by build_tabular_model, although build_tabular_model is just a convenience function. The workaround is to build the model directly and pass it to a Learner:
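```python
from fastai.tabular.all import *  # untar_data, URLs, pd, TabularDataLoaders, procs, Learner
from tsai.models.TabTransformer import TabTransformer

# Download the Adult sample dataset and build tabular DataLoaders
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, y_names="salary",
    cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race'],
    cont_names = ['age', 'fnlwgt', 'education-num'],
    procs = [Categorify, FillMissing, Normalize])

# Build TabTransformer directly from the categories, continuous columns and number of classes
model = TabTransformer(dls.classes, dls.cont_names, dls.c)
learn = Learner(dls, model)
learn.fit_one_cycle(1)
```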
While testing build_tabular_model, I tried passing it TabTransformer, but it failed to build a model:
```python
tab = build_tabular_model(TabTransformer, dls=dls_cat)
```
The error I got was:
```
TypeError: __init__() got an unexpected keyword argument 'y_range'
```
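The `TypeError` suggests that `build_tabular_model` forwards a `y_range` keyword (a parameter of fastai's `TabularModel` constructor) to whatever `arch` it receives, and TabTransformer's constructor has no such parameter. Below is a minimal, self-contained sketch of that failure mode and of one possible way a builder could tolerate differing constructors: inspect the constructor with `inspect.signature` and forward only the keywords it accepts. All names here (`TabularModelLike`, `TabTransformerLike`, `naive_build`, `build_with_supported_kwargs`) are hypothetical. The stand-in classes only mimic the shape of the real signatures and are not part of tsai or fastai.

```python
import inspect

# Hypothetical stand-ins that mimic only the *shape* of the two constructors.
# They are NOT the real fastai TabularModel or tsai TabTransformer classes.
class TabularModelLike:
    # TabularModel-style: accepts a y_range keyword
    def __init__(self, emb_szs, n_cont, out_sz, layers, y_range=None):
        self.out_sz, self.y_range = out_sz, y_range

class TabTransformerLike:
    # TabTransformer-style: no y_range keyword
    def __init__(self, classes, cont_names, c_out, d_model=32):
        self.c_out, self.d_model = c_out, d_model

def naive_build(arch, *args, y_range=None, **kwargs):
    # Always forwards y_range, whatever arch is (reproduces the TypeError)
    return arch(*args, y_range=y_range, **kwargs)

def build_with_supported_kwargs(arch, *args, **kwargs):
    """Call arch(*args, **kwargs), dropping keywords its __init__ does not accept."""
    params = inspect.signature(arch).parameters
    # A constructor that takes **kwargs accepts any keyword, so pass everything
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return arch(*args, **kwargs)
    accepted = {k: v for k, v in kwargs.items() if k in params}
    return arch(*args, **accepted)

# The naive builder fails for the TabTransformer-style class...
try:
    naive_build(TabTransformerLike, {}, ["age"], 2)
except TypeError as err:
    print(err)  # message mentions the unexpected 'y_range' keyword

# ...while the filtering builder works for both signatures
tt = build_with_supported_kwargs(TabTransformerLike, {}, ["age"], 2, y_range=(0, 1), d_model=64)
tm = build_with_supported_kwargs(TabularModelLike, [(10, 6)], 3, 2, [200, 100], y_range=(0, 1))
print(tt.d_model, tm.y_range)  # 64 (0, 1)
```

Silently dropping `y_range` means a requested output range would simply be ignored for TabTransformer, so a stricter variant might warn instead. This approach also addresses only the keyword mismatch. The two constructors expect different positional arguments: the workaround above passes `dls.classes, dls.cont_names, dls.c` to TabTransformer, while TabularModel takes embedding sizes, the number of continuous columns, the output size and the layer sizes. A builder supporting both would therefore still need to choose those arguments per architecture.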