HazyResearch / hippo-code

Apache License 2.0

How to handle statefulness in an online HiPPO implementation? #7

Closed (farshchian closed this issue 3 years ago)

farshchian commented 3 years ago

I am training HiPPO models on time series where the model predicts an output for every time step of the input sequence. During training, I reset the states at the beginning of each batch, while at test time I reset the state only at time zero. While this is common practice with traditional recurrent networks like LSTMs and GRUs, I am getting poor test performance with HiPPO models (LagT, LegT, and LegS). The performance improves if I reset the states at the exact sequence length that I used during training; however, this has the downside of erasing the memory. What would you recommend for maintaining statefulness across the three elements of the state (i.e., h, memory, and time_step) during online, real-time predictions?
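To make the two test-time policies concrete, here is a minimal pure-Python sketch. It does not use the hippo-code cell: `State`, `toy_step`, and `run` are hypothetical stand-ins, and the only thing borrowed from the question is the shape of the state (h, memory, time_step). The toy update weights each input by 1/time_step, so resetting the counter changes how later steps behave.

```python
from dataclasses import dataclass


@dataclass
class State:
    # Mirrors the three state elements named in the question (assumed layout).
    h: float = 0.0
    memory: float = 0.0
    time_step: int = 0


def toy_step(x, s):
    """Hypothetical stand-in for one recurrent update (NOT the hippo-code cell)."""
    t = s.time_step + 1
    # Running average of the input: the update weight depends on t,
    # so a reset of time_step alters every subsequent update.
    memory = s.memory + (x - s.memory) / t
    h = 0.5 * s.h + 0.5 * memory
    return h, State(h=h, memory=memory, time_step=t)


def run(xs, state=None):
    """Process a chunk, starting from `state` (fresh state if None)."""
    if state is None:
        state = State()
    outs = []
    for x in xs:
        y, state = toy_step(x, state)
        outs.append(y)
    return outs, state


series = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# Reference: the whole series in one pass.
full, _ = run(series)

# Stateful streaming: carry the state across the chunk boundary.
carried, st = run(series[:3])
tail, st = run(series[3:], st)
carried = carried + tail

# Reset at every chunk, as done during training in the question.
reset = run(series[:3])[0] + run(series[3:])[0]
```

In this toy setting, carrying the state reproduces the single-pass outputs exactly, while resetting per chunk diverges after the first boundary. Whether a trained model tolerates the longer effective horizon is a separate question that this sketch does not address.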

albertfgu commented 3 years ago

This is an interesting phenomenon. I don't think we ever trained in settings where the test sequences were much longer than the training sequences, so we didn't encounter this. One note is that I would expect LegS to not work well in this setting (because it is not "stationary" in a sense, so its behavior will be quite different on longer sequences), but I would have expected LegT to be OK.
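A rough sketch of the "stationarity" point, based on the continuous-time dynamics given in the HiPPO paper. The notation here is an assumption and may not match the variable names in the code; $c(t)$ is the coefficient (memory) vector, $f(t)$ the input, and $A$, $B$ fixed matrices specific to each measure.

$$
\text{LegT / LagT:}\quad \frac{d}{dt}c(t) = -A\,c(t) + B\,f(t)
$$

$$
\text{LegS:}\quad \frac{d}{dt}c(t) = -\frac{1}{t}A\,c(t) + \frac{1}{t}B\,f(t),
\qquad
c_{k+1} = \Bigl(1 - \frac{A}{k}\Bigr)c_k + \frac{1}{k}B\,f_k \ \ \text{(forward Euler)}
$$

For LegT and LagT the update does not depend on the absolute time, so the same operator is applied at every step. For LegS the update scales with $1/t$ (or $1/k$ after discretization), so steps beyond the training length use scalings the model never saw during training, and resetting time_step changes the dynamics rather than just clearing the memory.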

A few questions about the setting: What are the lengths of the training and test sequences? Are all training sequences the same length, with the test sequences much longer? Is this a setting where you have very long sequences and are chunking them during training? If so, can states be passed between batches during training instead of being reset?

farshchian commented 3 years ago

Thanks for the prompt response! As you correctly pointed out, LegS did not work well, and the best performance was indeed achieved with LegT. To answer your questions, the data consists of long time series that we are chunking during training. Keeping the state alive between batches would be difficult due to large variations in the length of the datasets and random discontinuities within each dataset. Perhaps a possible intermediate solution would be to increase the sequence length and use truncated backpropagation through time (TBPTT).
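One possible way to combine state passing with the discontinuities mentioned above is to split each series into contiguous segments, carry the state between chunks inside a segment, and reset only where a segment starts. The helper below is a hypothetical sketch of that bookkeeping only: `tbptt_chunks` and the `breaks` argument are names made up for illustration and are not part of hippo-code.

```python
def tbptt_chunks(length, chunk_len, breaks=()):
    """Yield (start, stop, reset) windows over a series of `length` steps.

    `breaks` lists the indices at which a new contiguous segment begins.
    `reset` is True for the first window of each segment (state should be
    re-initialized) and False otherwise (state should be carried over).
    """
    boundaries = sorted({0, length, *[b for b in breaks if 0 < b < length]})
    for seg_start, seg_stop in zip(boundaries, boundaries[1:]):
        for start in range(seg_start, seg_stop, chunk_len):
            stop = min(start + chunk_len, seg_stop)
            yield start, stop, start == seg_start


# A 10-step series with a discontinuity at index 6, chunked into windows of 4.
windows = list(tbptt_chunks(10, 4, breaks=[6]))
# -> [(0, 4, True), (4, 6, False), (6, 10, True)]

# Intended use inside a training loop (hypothetical model signature):
#   for start, stop, reset in tbptt_chunks(len(x), chunk_len, breaks):
#       if reset:
#           state = None                      # fresh h, memory, time_step
#       out, state = model(x[start:stop], state)
#       state = cut_gradients(state)          # keep values, drop the graph
```

In PyTorch, the gradient cut in the last comment would typically be done by calling `detach()` on each tensor in the state, so that values flow across chunk boundaries while backpropagation stays within a chunk. Windows never straddle a discontinuity, at the cost of some shorter chunks at segment ends.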

albertfgu commented 3 years ago

Yeah, that may be a good approach. This is an interesting problem to explore!