timeseriesAI / tsai

State-of-the-art Deep Learning library for Time Series and Sequences in Pytorch / fastai
https://timeseriesai.github.io/tsai/
Apache License 2.0
5.25k stars, 656 forks

Does build_tabular_model accept TabTransformer? #358

Open strakehyr opened 2 years ago

strakehyr commented 2 years ago

While testing build_tabular_model, I tried passing it TabTransformer, but it could not generate a model:

tab = build_tabular_model(TabTransformer, dls = dls_cat)

The error I got was:

TypeError: __init__() got an unexpected keyword argument 'y_range'

radi-cho commented 2 years ago

Hello. I am not a maintainer of the repository, but my PRs #365 and #362 are currently awaiting approval, so I am getting familiar with the codebase. From https://github.com/timeseriesAI/tsai/blob/main/tsai/models/utils.py#L177 and https://github.com/timeseriesAI/tsai/blob/main/tsai/tslearner.py#L50 we can see that build_tabular_model and the TSLearners are tailored to TabularModel only, not to TabTransformer. I plan to open a new pull request to improve support for TabTransformer (and GatedTabTransformer) once my current PRs are finished. In the meantime, you can support them with a thumbs up :)
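To make the mechanism concrete, here is a framework-free sketch. It assumes, based on the traceback in the original report, that build_tabular_model calls the architecture with a TabularModel-style layout (fastai's TabularModel takes emb_szs, n_cont, out_sz, layers, ..., y_range), whereas TabTransformer is instantiated as TabTransformer(dls.classes, dls.cont_names, dls.c) in the workaround in this thread. ToyTabularModel, ToyTabTransformer, naive_build, dispatching_build and the info dict are all hypothetical; this is not tsai's code and not the change planned in the PRs above.

```python
import inspect


# Hypothetical toy classes: they only mimic the constructor layouts of
# fastai's TabularModel and tsai's TabTransformer; they build no network.
class ToyTabularModel:
    def __init__(self, emb_szs, n_cont, out_sz, layers, y_range=None):
        self.emb_szs, self.n_cont, self.out_sz = emb_szs, n_cont, out_sz
        self.layers, self.y_range = layers, y_range


class ToyTabTransformer:
    def __init__(self, classes, cont_names, c_out, d_model=32):
        self.classes, self.cont_names, self.c_out = classes, cont_names, c_out
        self.d_model = d_model


def naive_build(arch, info, layers=(200, 100), y_range=None, **kwargs):
    # Hypothetical builder written for one architecture: it always uses the
    # TabularModel-style call, whatever `arch` actually is.
    return arch(info["emb_szs"], len(info["cont_names"]), info["c"],
                list(layers), y_range=y_range, **kwargs)


def dispatching_build(arch, info, layers=(200, 100), y_range=None, **kwargs):
    # Hypothetical fix: choose the argument layout from the constructor signature.
    params = inspect.signature(arch.__init__).parameters
    if "classes" in params:
        # TabTransformer-style constructor; layers/y_range are TabularModel
        # options and are deliberately not forwarded here.
        return arch(info["classes"], info["cont_names"], info["c"], **kwargs)
    return naive_build(arch, info, layers=layers, y_range=y_range, **kwargs)


# Hypothetical metadata standing in for what a fastai DataLoaders exposes.
info = {
    "emb_szs": [(2, 2), (2, 2)],
    "classes": {"workclass": ["#na#", "Private"], "race": ["#na#", "White"]},
    "cont_names": ["age", "fnlwgt", "education-num"],
    "c": 2,
}

if __name__ == "__main__":
    try:
        naive_build(ToyTabTransformer, info)
    except TypeError as err:
        print("naive_build failed:", err)  # mentions the unexpected 'y_range'
    model = dispatching_build(ToyTabTransformer, info)
    print(type(model).__name__, model.c_out)
```

Note that dropping y_range alone would not be enough: the positional arguments of the two constructors also differ, which is why the sketch switches the whole call layout rather than just filtering keywords.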

oguiza commented 2 years ago

Hi @strakehyr and @radi-cho, that's a good point. It would be good to have TabTransformer supported by build_tabular_model. However, build_tabular_model is just a convenience function, so the workaround is to build the model directly:

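from fastai.tabular.all import *  # untar_data, URLs, TabularDataLoaders, Learner, Categorify, FillMissing, Normalize
from tsai.models.TabTransformer import TabTransformer
import pandas as pd

# Build fastai tabular DataLoaders for the Adult sample dataset
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, y_names="salary",
    cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race'],
    cont_names = ['age', 'fnlwgt', 'education-num'],
    procs = [Categorify, FillMissing, Normalize])
# Instantiate TabTransformer with its own arguments: category vocabularies, continuous column names, number of outputs
model = TabTransformer(dls.classes, dls.cont_names, dls.c)
# Wrap the model in a standard fastai Learner and train for one epoch
learn = Learner(dls, model)
learn.fit_one_cycle(1)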