Hi, thanks for the question and feedback! This variance does seem much higher than expected. After refactoring the code for the public release, I think I did observe this occasionally for both S5 and a JAX S4D implementation when tested on some hardware, but it was pretty rare. But it is possible some bug did get introduced during the refactor.
Yes good catch and thanks for pointing this out. The intention was to include these parameters in the SSM parameters. However, it appears the results in the paper were computed using the current setup (where the norm parameters are not actually included in the SSM parameters), so I wouldn't expect this to have a major effect. But perhaps it could help?
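For context, here is a minimal sketch of the kind of leaf-keyed labeling helper under discussion (`map_nested_fn`, modeled on the optax `multi_transform` docs example; the parameter tree below is hypothetical, not the repo's actual structure):

```python
# Sketch of an optax-style map_nested_fn, assumed similar to the one in this
# repo (see the optax multi_transform docs for the original pattern).
def map_nested_fn(fn):
    """Recursively apply `fn(key, leaf)` to every leaf of a nested dict."""
    def map_fn(nested_dict):
        return {
            k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
            for k, v in nested_dict.items()
        }
    return map_fn

# Hypothetical parameter tree: norm parameters live under a "norm" subtree.
params = {"norm": {"scale": 1.0}, "B": 2.0}

# Because `fn` only ever sees leaf keys ("scale", "B"), a rule keyed on the
# branch name "norm" never fires, so the norm parameters land in the default
# group instead of the SSM group.
label_fn = map_nested_fn(lambda k, _: "ssm" if k in ("B", "norm") else "regular")
print(label_fn(params))  # {'norm': {'scale': 'regular'}, 'B': 'ssm'}
```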
The broadcast in dropout keeps the same dropout pattern across the sequence. This was copied from the annotated-s4 implementation here: https://github.com/srush/annotated-s4/blob/f19e2464990f0c07943bb922894d6f2af65c9bd9/s4/s4.py#L515, and I believe the official Torch S4 implementation applies dropout the same way, though I would have to dig through that codebase to confirm.
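To illustrate, a minimal sketch of broadcast dropout (assuming an input of shape `(sequence length, features)`, as in the linked annotated-s4 code; shapes and rates here are illustrative):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# With broadcast_dims=[0], the dropout mask is sampled once over the feature
# axis and reused along axis 0 (the sequence axis), so every timestep drops
# the same features.
drop = nn.Dropout(rate=0.5, broadcast_dims=[0], deterministic=False)

x = jnp.ones((16, 8))  # (L=16, d_model=8)
y = drop.apply({}, x, rngs={"dropout": jax.random.PRNGKey(0)})

# Each column of y is either all zeros or all 2.0 (inverted-dropout scaling
# by 1/(1-rate)): the pattern is constant across the sequence dimension.
print(jnp.all((y == 0).all(axis=0) | (y == 2.0).all(axis=0)))  # True
```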
Thanks for your reply! Do you plan to investigate what the potential bug is? I think it is easy to reproduce my observation since I didn't change anything that could affect results (except for package versions, but I don't think this matters).
I just noticed that in the S4D paper (Table 4) the initialization range of the timescale for PathX is [0.0001, 0.01], yet in both your paper (Appendix B.1.3) and the released code (no `dt_max` specified here) you use [0.0001, 0.1]. Is there any chance that you actually used [0.0001, 0.01] in your private codebase?
Btw, on Pathfinder (not PathX), out of 10 seeds there are usually 1 or 2 failed runs, but I guess this is fine.
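For reference, a minimal sketch of the log-uniform timescale initialization these ranges parameterize (the standard S4/S4D/S5-style init, assuming the `dt_min`/`dt_max` names used in this thread):

```python
import jax
import jax.numpy as jnp

# Log-uniform timescale init, as in S4/S4D/S5-style models: log(dt) is drawn
# uniformly from [log(dt_min), log(dt_max)]. The question above is whether
# dt_max should be 0.1 (the released default) or 0.01 (Table 4 of S4D).
def log_dt_initializer(dt_min=1e-4, dt_max=1e-1):
    def init(key, shape):
        u = jax.random.uniform(key, shape)
        return u * (jnp.log(dt_max) - jnp.log(dt_min)) + jnp.log(dt_min)
    return init

dt = jnp.exp(log_dt_initializer()(jax.random.PRNGKey(0), (4,)))
print(dt)  # values fall in [1e-4, 1e-1]
```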
It seems we did use the current defaults and reported value of [0.0001, 0.1].
We can look into it further, but do not have a specific timeline.
Got it, thanks!
Hi,
First, thanks for this awesome work and repo :) I'm trying to reproduce the results on PathX using this codebase, and I'm noticing high variance across seeds. I ran 10 different seeds (jax_seed=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); 6 of them worked as expected, but 4 failed to learn anything. Since your paper averages over 3 seeds, this level of variance seems strange. Any idea what could be wrong? Or is this level of variance expected?
Also, I have some questions about the code:

1. It looks like the intention is to include the norm parameters in the SSM parameter group, yet `map_nested_fn` only acts on leaf nodes in the tree, so they end up excluded. Is this a bug? Which one is the intended behavior?
2. What is the purpose of `broadcast_dims=[0]` in dropout, for example here?

Thanks in advance!