Closed: LGro closed this issue 5 months ago
Hello @LGro,
Thank you for reporting this bug. I think this comes from the tolerances used in `FreeFormJacobianTransform`, which are much looser in LAMPE/Zuko (`1e-5`) than in the original implementation (`1e-7`). Could you try modifying the `atol` and `rtol` of the `FreeFormJacobianTransform` used in `log_prob` and repeat your experiments? It might also be worth running in double precision (`float64`).
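To see why the double-precision suggestion matters, here is a tiny self-contained illustration (toy numbers, unrelated to LAMPE's internals): repeatedly rounding intermediates to `float32` accumulates error that is orders of magnitude larger than the `float64` equivalent.

```python
import struct

def as_f32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack('f', struct.pack('f', x))[0]

acc32 = 0.0
acc64 = 0.0
for _ in range(10_000):
    acc32 = as_f32(acc32 + as_f32(1e-4))  # every intermediate rounded to float32
    acc64 = acc64 + 1e-4                  # full float64 accumulation

# Both sums target 1.0; the float32 accumulator drifts much further from it.
print(abs(acc32 - 1.0), abs(acc64 - 1.0))
```

The same effect applies inside an ODE solver, where thousands of floating-point operations compound along the trajectory.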
Thanks for digging into this issue with me :relaxed:
Indeed, shrinking the tolerances while running the estimator and inputs at `float64` precision does reduce the initially observed discrepancy. Do I understand correctly that the discrepancy is not problematic per se, as long as its magnitude is irrelevant for one's application?
> does reduce the initially observed discrepancy

It does not vanish with both absolute and relative tolerances at `1e-7`?
> Do I understand it right that the discrepancy is not problematic per-se as long as the magnitude is irrelevant for one's application?
Yes, the discrepancy is not an implementation or method issue, but a numerical one. If it is small enough, it should not affect downstream tasks. It could be worth adding an option to modify the tolerances in the `FMPE` class though, or perhaps a warning in the docstring.
> It does not vanish with both absolute and relative tolerances at `1e-7`?
For tolerances at `1e-9`, the differences go down to the order of `1e-5` or `1e-6`, which was enough of an indicator for me. I have not tried to push it to the limit of `float64` precision.
How does this compare to the official implementation (at `1e-7`)? If, at the same tolerance, the official implementation shows less discrepancy between batched and unbatched evaluation, it could be worth investigating further.
Description
When computing the log probability with FMPE's `log_prob` method, the resulting probability values depend on the other input elements in the batch. The change I saw was on the order of the third or fourth decimal place.
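One plausible mechanism for such coupling, sketched here with a toy solver (all names and numbers are hypothetical, this is not LAMPE's actual code): an adaptive ODE integrator that shares its step-size control across the batch accepts or rejects each step based on the worst error in the batch, so every element's result depends on its batchmates.

```python
import math

def integrate(y0s, rates, rtol=1e-5):
    """Integrate dy/dt = -rate * y from t=0 to t=1 with a SHARED adaptive step."""
    ys, t, h = list(y0s), 0.0, 0.05
    while t < 1.0:
        h = min(h, 1.0 - t)
        # Euler and Heun estimates for every batch element
        euler = [y - h * k * y for y, k in zip(ys, rates)]
        heun = [y - 0.5 * h * (k * y + k * ye) for y, k, ye in zip(ys, rates, euler)]
        # shared error control: the stiffest batch element dictates the step
        err = max(abs(a - b) / (abs(b) + 1e-12) for a, b in zip(euler, heun))
        if err < rtol:  # accept the step for the whole batch at once
            ys, t = heun, t + h
        h = h * 0.9 * math.sqrt(rtol / (err + 1e-16))
    return ys

alone = integrate([1.0], [1.0])[0]               # y(1) for rate 1, integrated alone
batched = integrate([1.0, 1.0], [1.0, 8.0])[0]   # same element, stiffer batchmate
print(alone, batched, abs(alone - batched))
```

Both values approximate exp(-1), but they differ deterministically by roughly the solver tolerance, because the stiffer batchmate forces smaller steps for everyone. Tightening `rtol` shrinks the gap, matching the behaviour discussed above.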
In any case, thanks a lot already for your work on LAMPE :relaxed:
Reproduce
Following the example, the two ways to compute log probabilities for a given configuration `theta` and a batch of corresponding simulated results `x` produce different results:

Expected behavior
I would expect that the individual log probability values for one `theta` and `x` pair are not affected by the other entries in the `X` batch. This is corroborated by the official implementation not showing that behaviour when evaluating `log_prob_batch` with different subsets for the batch. In the above example, I would expect both to result in e.g.
`[3.1978, 1.8175, 2.4526, 1.6468, 3.0495, 2.5894, 2.7065, 2.7712, 3.3385, 1.5558]`.

Causes and solution
I have no clear intuition why that would be the case. I suspected a stochastic influence and that the `FreeFormJacobianTransform` exact mode might help, but it seems to be a deterministic difference, and setting `exact=True` did not change it. I noticed that the LAMPE implementation uses a trigonometric embedding of the time dimension for the vector field computation while the official implementation by the authors does not, but it is also not obvious to me that this would explain the difference.

Environment