lindermanlab / S5

MIT License

Reproducing Results on PathX #10

Closed zhixuan-lin closed 7 months ago

zhixuan-lin commented 8 months ago

Hi,

First, thanks for this awesome work and repo :) I'm trying to reproduce the results on PathX using this codebase, and I'm noticing high variance across seeds. I ran 10 different seeds (jax_seed=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]): 6 of them worked as expected, but 4 of them failed to learn anything. It seems that in your paper you average over 3 seeds, so this level of variance seems strange to me. Any idea what could be wrong? Or is this level of variance expected?

Also, I have some questions about the code:

  1. Here it seems that you intend to include the LayerNorm/BatchNorm parameters in the SSM parameter group. However, this has no effect, since map_nested_fn only acts on leaf nodes of the tree. Is this a bug, and which behavior is intended? (See the sketch after this list.)
  2. Is there a reason you use broadcast_dims=[0] in dropout, for example here?
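
To make point 1 concrete, here is a minimal, runnable sketch of the behavior I mean. The parameter tree below is made up for illustration (the key names mirror the repo's), and map_nested_fn is written the way I read the repo's version:

```python
# map_nested_fn recurses into sub-dicts and only calls fn on leaves,
# so a subtree key like "norm" never reaches fn at all.
def map_nested_fn(fn):
    def map_fn(nested_dict):
        return {
            k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
            for k, v in nested_dict.items()
        }
    return map_fn

label_fn = map_nested_fn(
    lambda k, _: "ssm"
    if k in ["Lambda_re", "Lambda_im", "B", "log_step", "norm"]
    else "regular"
)

# Hypothetical parameter tree; "norm" is a subtree, not a leaf.
params = {
    "seq": {"Lambda_re": 1.0, "B": 2.0},
    "norm": {"scale": 3.0, "bias": 4.0},
}
print(label_fn(params))
# {'seq': {'Lambda_re': 'ssm', 'B': 'ssm'},
#  'norm': {'scale': 'regular', 'bias': 'regular'}}
# The norm leaves are labelled by their own keys ("scale"/"bias"),
# so they end up in the "regular" group rather than "ssm".
```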

Thanks in advance!

jimmysmith1919 commented 8 months ago

Hi, thanks for the question and feedback! This variance does seem much higher than expected. After refactoring the code for the public release, I think I did observe this occasionally for both S5 and a JAX S4D implementation when tested on some hardware, but it was pretty rare. But it is possible some bug did get introduced during the refactor.

  1. Yes, good catch, and thanks for pointing this out. The intention was to include these parameters in the SSM parameter group. However, it appears the results in the paper were computed with the current setup (where the norm parameters are not actually included in the SSM group), so I wouldn't expect fixing this to have a major effect. But perhaps it could help?

  2. The broadcast in dropout keeps the same dropout pattern across the sequence; see the sketch below. This was copied from the annotated S4 implementation here: https://github.com/srush/annotated-s4/blob/f19e2464990f0c07943bb922894d6f2af65c9bd9/s4/s4.py#L515, and I believe this is also how the official Torch S4 implementation applies dropout, though I would have to dig through that codebase to confirm.
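
For concreteness, here is a small Flax sketch of that effect (the shapes, rate, and single-example layout are illustrative, not the repo's exact usage):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

key = jax.random.PRNGKey(0)
x = jnp.ones((4, 3))  # (L, H): sequence length 4, 3 channels

# broadcast_dims=[0] ties the dropout mask across axis 0 (the length axis),
# so each channel is either kept or dropped for the entire sequence.
drop = nn.Dropout(rate=0.5, broadcast_dims=[0], deterministic=False)
y = drop.apply({}, x, rngs={"dropout": key})
print(y)  # each column is all 0.0 or all 2.0 (kept values scaled by 1/(1-rate))
```

Without the broadcast, each (position, channel) entry would be masked independently.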

zhixuan-lin commented 8 months ago

Thanks for your reply! Do you plan to investigate what the potential bug is? My observation should be easy to reproduce, since I didn't change anything that could affect the results (except package versions, but I don't think those matter).

zhixuan-lin commented 8 months ago

I just noticed that in S4D (Table 4) the initialization range for the timescale on PathX is [0.0001, 0.01], yet both your paper (Appendix B.1.3) and the released code (no dt_max specified here) use [0.0001, 0.1]. Is there any chance you actually used [0.0001, 0.01] in your private codebase?
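
For reference, this is the timescale init I'm referring to. The sketch below assumes the usual S4/S5 log-uniform convention (the function and names are illustrative, not the repo's exact code); shrinking dt_max by 10x would bias the init toward slower dynamics, i.e. longer memory:

```python
import jax
import jax.numpy as jnp

# Log-uniform init: log_step ~ Uniform(log(dt_min), log(dt_max)), so the
# step sizes exp(log_step) land in [dt_min, dt_max].
def init_log_steps(key, H, dt_min=1e-4, dt_max=1e-1):
    u = jax.random.uniform(key, (H,))
    return u * (jnp.log(dt_max) - jnp.log(dt_min)) + jnp.log(dt_min)

log_steps = init_log_steps(jax.random.PRNGKey(0), H=8, dt_max=1e-2)
print(jnp.exp(log_steps))  # all step sizes now lie in [1e-4, 1e-2]
```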

Btw, on Pathfinder (not PathX), roughly 1 or 2 out of 10 seeds fail, but I guess this is fine.

jimmysmith1919 commented 8 months ago

It seems we did use the current defaults and the reported value of [0.0001, 0.1].

We can look into it further, but do not have a specific timeline.

zhixuan-lin commented 8 months ago

Got it, thanks!