By default, `valid_batch_size = 1024` in `BaseModel`, which causes a problem with `BaseMultivariate` when `n_series > 1024`. This fix sets `valid_batch_size = n_series`, matching how `batch_size` is already handled in `BaseMultivariate`.
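The sketch below illustrates the idea behind the fix. `BaseModelSketch`, `BaseMultivariateSketch` and `validation_batches` are hypothetical, simplified stand-ins, not the library's actual classes. It assumes, as the description implies, that a multivariate model must see all `n_series` together in every batch. The `ValueError` is only a toy stand-in for the failure reported in #948.

```python
class BaseModelSketch:
    """Hypothetical stand-in for BaseModel: only the batch-size defaults."""

    def __init__(self, batch_size=32, valid_batch_size=1024):
        self.batch_size = batch_size
        self.valid_batch_size = valid_batch_size


class BaseMultivariateSketch(BaseModelSketch):
    """Hypothetical stand-in for BaseMultivariate.

    Assumes the model must see all n_series together in every batch,
    so both batch sizes are tied to n_series.
    """

    def __init__(self, n_series, **kwargs):
        super().__init__(**kwargs)
        self.n_series = n_series
        self.batch_size = n_series        # existing rule for training batches
        self.valid_batch_size = n_series  # the fix: same rule for validation

    def validation_batches(self, series_ids):
        # Chunk the series axis into groups of valid_batch_size, as a
        # data loader would, and reject any batch that is missing series.
        step = self.valid_batch_size
        batches = [series_ids[i:i + step] for i in range(0, len(series_ids), step)]
        for batch in batches:
            if len(batch) != self.n_series:
                raise ValueError(
                    f"batch has {len(batch)} series, model expects {self.n_series}"
                )
        return batches


model = BaseMultivariateSketch(n_series=2000)
print(len(model.validation_batches(list(range(2000)))))  # 1 batch of 2000 series
```

With the old default of 1024, the same 2000 series would be split into batches of 1024 and 976, and neither batch would contain every series.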
Summary:
Fixes #948
Added a unit test to the TSMixerx model that fails without the fix. The test should also work for other multivariate models; I only included it for TSMixerx because the bug relates to `BaseMultivariate` rather than to a specific model.