fastai / fastai

The fastai deep learning library
http://docs.fast.ai
Apache License 2.0

Weighted data loading of tabular data #3860

Open minkooseo opened 1 year ago

minkooseo commented 1 year ago

Link to forum discussion:
[1] https://forums.fast.ai/t/weighted-dataloaders-with-tabularpandas/102057/2
[2] https://forums.fast.ai/t/using-weighteddl-with-tabulardata/80789

I couldn't find anyone who has successfully used tabular data with sample weights.

Is your feature request related to a problem? Please describe.

It's necessary to handle class imbalance in tabular data.

Describe the solution you'd like

The most straightforward solution would be to add a wgts argument to TabularDataLoaders: https://docs.fast.ai/tabular.data.html
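To make the request concrete, here is a minimal sketch of what per-sample weights would mean for a data loader: each epoch, row indices are drawn with probability proportional to their weight and then cut into batches. This is plain NumPy and does not use fastai internals; `weighted_epoch_indices` and `batches` are hypothetical helper names used only for illustration.

```python
import numpy as np

def weighted_epoch_indices(wgts, rng):
    """Draw one epoch of row indices, with replacement, proportional to wgts."""
    wgts = np.asarray(wgts, dtype=float)
    probs = wgts / wgts.sum()  # normalize so the weights form a distribution
    return rng.choice(len(probs), size=len(probs), replace=True, p=probs)

def batches(idxs, bs):
    """Yield consecutive slices of at most bs indices."""
    for start in range(0, len(idxs), bs):
        yield idxs[start:start + bs]

rng = np.random.default_rng(0)
# 95 common rows with weight 1 and 5 rare rows with weight 19:
# both groups carry the same total weight (95 each).
wgts = [1.0] * 95 + [19.0] * 5
idxs = weighted_epoch_indices(wgts, rng)
first_batch = next(batches(idxs, bs=20))
rare_fraction = float((idxs >= 95).mean())
print(first_batch, rare_fraction)
```

With these weights the rare rows should make up roughly half of each epoch even though they are only 5% of the data.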

I tried to use TabularPandas in combination with calling dataloaders(dl_type=WeightedDL, bs=64, wgts=...), but it leads to the error described in [2]. I tried to fix it myself, but it's too complex for me. I observed two issues:

Describe alternatives you've considered

Instead of using WeightedDL, I could resample my data frame before calling dataloaders, but that would be very cumbersome.
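For reference, a rough sketch of that resampling workaround using pandas only. It assumes inverse class frequency is an acceptable weighting scheme; `inverse_frequency_weights` and `resample_by_weight` are hypothetical helper names, and the toy data frame stands in for the real one.

```python
import pandas as pd

def inverse_frequency_weights(labels: pd.Series) -> pd.Series:
    """Weight each row by 1 / (number of rows sharing its label)."""
    counts = labels.map(labels.value_counts())
    return 1.0 / counts

def resample_by_weight(df, weights, n=None, seed=0):
    """Sample n rows with replacement, proportional to weights."""
    n = len(df) if n is None else n
    sampled = df.sample(n=n, replace=True, weights=weights, random_state=seed)
    return sampled.reset_index(drop=True)

# Toy imbalanced frame: 90 rows of class "A", 10 rows of class "B".
toy = pd.DataFrame({"age": range(100), "race": ["A"] * 90 + ["B"] * 10})
wgts = inverse_frequency_weights(toy["race"])
balanced = resample_by_weight(toy, wgts, n=1000)
print(balanced["race"].value_counts(normalize=True))
```

The resampled frame could then be passed to TabularPandas as usual. The drawback is that the sample is fixed once up front rather than redrawn every epoch, and validation rows would need to be kept out of the resampling.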

Additional context

Steps to reproduce:

from fastai.tabular.all import *

path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv', skipinitialspace=True)
df.head()
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]

to = TabularPandas(
    df, procs=procs,
    cat_names=cat_names, cont_names=cont_names, y_names='race',
    y_block=CategoryBlock())

# works ok
dls = to.dataloaders(bs=20)
dls.show_batch()

# fails.
# "Could not do one pass in your dataloader, there is something wrong in it. Please see the stack trace below"
dls = to.dataloaders(bs=20, dl_type=WeightedDL, wgts=[1.0] * len(df))

myenugula commented 9 months ago

I have solved this issue with PR #3995. Now you can do the following and it will run successfully:

dls = to.dataloaders(bs=20, dl_type=TabWeightedDL, wgts=[1.0] * len(df))