timeseriesAI / tsai

Time series Timeseries Deep Learning Machine Learning Python Pytorch fastai | State-of-the-art Deep Learning library for Time Series and Sequences in Pytorch / fastai
https://timeseriesai.github.io/tsai/
Apache License 2.0
5.19k stars 649 forks source link

assertion error of minibatch length for regression objective #46

Closed geoHeil closed 3 years ago

geoHeil commented 3 years ago

When trying to perform a regression using:

window_length = 120
stride = None
get_x = ['foo', 'bar', 'baz']

get_y = 'y_float'

horizon = 8

X, y = SlidingWindowPanel(window_length, ['id_1', 'id_2'], stride, get_x=get_x, get_y=get_y,  horizon=horizon, seq_first=True, sort_by=['hour'], ascending=True, check_leakage=True, return_key=False, verbose=True)(df)

                        sort_by=['hour'], ascending=True, check_leakage=True, return_key=False, verbose=True)

splits = get_splits(y, valid_size=.2, stratify=True, random_state=47, shuffle=False)
#tfms  = [None, [Categorize()]] << classification works just fine
tfms  = [None, [ToFloat(), ToNumpyTensor()]]
dsets = TSDatasets(X, y, tfms=tfms, splits=splits)

the code fails with:

AssertionError: ==:
2048
16384

However, a classification task with:

%%time

splits = get_splits(y, valid_size=.2, stratify=True, random_state=47, shuffle=False)
tfms  = [None, [Categorize()]]
dsets = TSDatasets(X, y, tfms=tfms, splits=splits)
dsets

with an integer class label works just fine.

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-9-a480b725f69a> in <module>
     18                    )#.to_fp16()
     19     start = time.time()
---> 20     learn.fit_one_cycle(300, 1e-4)
     21     elapsed = time.time() - start
     22     vals = learn.recorder.values[-1]

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt)
    110     scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
    111               'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 112     self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
    113 
    114 # Cell

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    204             self.opt.set_hypers(lr=self.lr if lr is None else lr)
    205             self.n_epoch = n_epoch
--> 206             self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
    207 
    208     def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153 
    154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
    157         finally:   self(f'after_{event_type}')        ;final()

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/learner.py in _do_fit(self)
    195         for epoch in range(self.n_epoch):
    196             self.epoch=epoch
--> 197             self._with_events(self._do_epoch, 'epoch', CancelEpochException)
    198 
    199     def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153 
    154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
    157         finally:   self(f'after_{event_type}')        ;final()

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/learner.py in _do_epoch(self)
    190     def _do_epoch(self):
    191         self._do_epoch_train()
--> 192         self._do_epoch_validate()
    193 
    194     def _do_fit(self):

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/learner.py in _do_epoch_validate(self, ds_idx, dl)
    186         if dl is None: dl = self.dls[ds_idx]
    187         self.dl = dl
--> 188         with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
    189 
    190     def _do_epoch(self):

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153 
    154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
    157         finally:   self(f'after_{event_type}')        ;final()

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/learner.py in all_batches(self)
    159     def all_batches(self):
    160         self.n_iter = len(self.dl)
--> 161         for o in enumerate(self.dl): self.one_batch(*o)
    162 
    163     def _do_one_batch(self):

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/learner.py in one_batch(self, i, b)
    177         self.iter = i
    178         self._split(b)
--> 179         self._with_events(self._do_one_batch, 'batch', CancelBatchException)
    180 
    181     def _do_epoch_train(self):

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
--> 157         finally:   self(f'after_{event_type}')        ;final()
    158 
    159     def all_batches(self):

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/learner.py in __call__(self, event_name)
    131     def ordered_cbs(self, event): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, event)]
    132 
--> 133     def __call__(self, event_name): L(event_name).map(self._call_one)
    134 
    135     def _call_one(self, event_name):

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastcore/foundation.py in map(self, f, gen, *args, **kwargs)
    152     def range(cls, a, b=None, step=None): return cls(range_of(a, b=b, step=step))
    153 
--> 154     def map(self, f, *args, gen=False, **kwargs): return self._new(map_ex(self, f, *args, gen=gen, **kwargs))
    155     def argwhere(self, f, negate=False, **kwargs): return self._new(argwhere(self, f, negate, **kwargs))
    156     def filter(self, f=noop, negate=False, gen=False, **kwargs):

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastcore/basics.py in map_ex(iterable, f, gen, *args, **kwargs)
    654     res = map(g, iterable)
    655     if gen: return res
--> 656     return list(res)
    657 
    658 # Cell

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastcore/basics.py in __call__(self, *args, **kwargs)
    644             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    645         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 646         return self.func(*fargs, **kwargs)
    647 
    648 # Cell

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/learner.py in _call_one(self, event_name)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name), event_name
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/learner.py in <listcomp>(.0)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name), event_name
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/callback/core.py in __call__(self, event_name)
     42                (self.run_valid and not getattr(self, 'training', False)))
     43         res = None
---> 44         if self.run and _run: res = getattr(self, event_name, noop)()
     45         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     46         return res

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/learner.py in after_batch(self)
    456         if len(self.yb) == 0: return
    457         mets = self._train_mets if self.training else self._valid_mets
--> 458         for met in mets: met.accumulate(self.learn)
    459         if not self.training: return
    460         self.lrs.append(self.opt.hypers[-1]['lr'])

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/learner.py in accumulate(self, learn)
    378     def accumulate(self, learn):
    379         bs = find_bs(learn.yb)
--> 380         self.total += learn.to_detach(self.func(learn.pred, *learn.yb))*bs
    381         self.count += bs
    382     @property

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/metrics.py in accuracy(inp, targ, axis)
     99 def accuracy(inp, targ, axis=-1):
    100     "Compute accuracy with `targ` when `pred` is bs * n_classes"
--> 101     pred,targ = flatten_check(inp.argmax(dim=axis), targ)
    102     return (pred == targ).float().mean()
    103 

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastai/torch_core.py in flatten_check(inp, targ)
    781     "Check that `out` and `targ` have the same number of elements and flatten them."
    782     inp,targ = inp.contiguous().view(-1),targ.contiguous().view(-1)
--> 783     test_eq(len(inp), len(targ))
    784     return inp,targ

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastcore/test.py in test_eq(a, b)
     33 def test_eq(a,b):
     34     "`test` that `a==b`"
---> 35     test(a,b,equals, '==')
     36 
     37 # Cell

~/development/conda_envs/my_env/lib/python3.8/site-packages/fastcore/test.py in test(a, b, cmp, cname)
     23     "`assert` that `cmp(a,b)`; display inputs and `cname or cmp.__name__` if it fails"
     24     if cname is None: cname=cmp.__name__
---> 25     assert cmp(a,b),f"{cname}:\n{a}\n{b}"
     26 
     27 # Cell

AssertionError: ==:
2048
16384
oguiza commented 3 years ago

Hi @geoHeil,

A couple of comments:

  1. I don't know how you are building the model, but it seems you are passing just c_out = 1. I'd recommend you use build_ts_model(arch, dls=dls), as that will automatically set the right parameters to build the model. In your case, for example, c_out needs to be the same as the horizon.
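  2. Based on the error trace, you seem to be using accuracy in a regression task. accuracy takes the argmax of the predictions, leaving one value per sample (2048 in your batch). Your target has horizon = 8 values per sample (2048 × 8 = 16384), so the length check fails. You should instead use regression metrics such as mse or mae.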

As an example, here's some code you may use to run a multi-output regression task:

dsid = 'NATOPS' 
X, y, splits = get_UCR_data(dsid, split_data=False)
tfms  = [None, TSRegression()]
dls = get_ts_dls(X, np.random.rand(y.shape[0], 8), tfms=tfms, splits=splits, bs=[64, 128])
learn = ts_learner(dls, InceptionTimePlus, metrics=[mae, mse])
learn.fit_one_cycle(1)

BTW, this approach also works for some models (those whose names end in Plus, as they accept a custom_head kwarg) even if you use a 3D target y:

dsid = 'NATOPS' 
X, y, splits = get_UCR_data(dsid, split_data=False)
tfms  = [None, TSRegression()]
dls = get_ts_dls(X, np.random.rand(y.shape[0], 3, 8), tfms=tfms, splits=splits, bs=[64, 128])
learn = ts_learner(dls, InceptionTimePlus, metrics=[mae, mse])
learn.fit_one_cycle(5)

This may be useful in multivariate, multi-step forecasting or 2d output regression.

geoHeil commented 3 years ago

many thanks!