[Open] EelcoHoogendoorn opened this issue 1 year ago
The version in the torch code is the one used in all experiments in the original paper. Any differences from the paper figure are probably due to a different interpretation. To be honest, the RNN cell was somewhat arbitrary, so there are a lot of reasonable alternatives. The original HiPPO-RNN cell has long been abandoned in favor of the S4 approach.
Hi all,
I'm working on a JAX implementation of a HiPPO-gated RNN. I wasn't quite sure how to interpret the diagram in the paper linked below, and indeed I can't quite reconcile it with the torch implementation, also linked below. The code seems more sensible to me than the paper: it makes sense to have the SSM see the raw, unfiltered data, whereas placing the nonlinear gated action in front seems like it might jeopardize the unobstructed flow of gradients along the sequence.
The version from the diagram in the paper works quite alright in my use case, though the torch version seems to converge more quickly. Just curious whether I'm misreading something here, or what your latest thinking on these matters is.
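For concreteness, here is how I read the two wirings. This is a minimal NumPy sketch: the transition matrices, the hidden-to-memory projection `w_f`, and the toy sigmoid/tanh gate are all placeholders of my own, not the actual HiPPO-RNN cell from the paper or the repo.

```python
import numpy as np

rng = np.random.default_rng(0)
N, H = 8, 16                          # memory (SSM) size, hidden size
A = rng.normal(size=(N, N)) * 0.1     # stand-in for the HiPPO transition matrix
B = rng.normal(size=N) * 0.1          # stand-in for the HiPPO input map
Wz = rng.normal(size=(H, H + 1 + N)) * 0.1
Wh = rng.normal(size=(H, H + 1 + N)) * 0.1
w_f = rng.normal(size=H) * 0.1        # hypothetical projection of hidden state to memory input

def gate(h, feats):
    # toy gated update (placeholder for the full GRU-style cell)
    v = np.concatenate([h, feats])
    z = 1.0 / (1.0 + np.exp(-(Wz @ v)))
    return (1.0 - z) * h + z * np.tanh(Wh @ v)

def paper_step(h, c, x):
    # paper diagram, as I read it: gated nonlinearity first,
    # then the SSM memorizes a projection of the *gated* hidden state
    h = gate(h, np.concatenate([[x], c]))
    c = A @ c + B * (w_f @ h)
    return h, c

def code_step(h, c, x):
    # torch code, as I read it: the SSM sees the raw input x directly,
    # and the gate then reads the updated memory
    c = A @ c + B * x
    h = gate(h, np.concatenate([[x], c]))
    return h, c
```

In the second variant the memory update `c -> A c + B x` stays linear in the input along the whole sequence, which is what I meant by unobstructed gradient flow; in the first, gradients into the memory must pass through the gate's nonlinearity at every step.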
EDIT: I have had very good experience with the paper version in terms of avoiding exploding gradients; while the code version seems to converge faster and more smoothly initially, I do observe that the gated unit can explode on longer trajectories. Lots of things to explore here, I suppose; deep/stacked SSMs with pointwise nonlinearities have not worked for me so far.