danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License

AutoregressiveBisectionInverter bounds picked up as trainable parameters #173

Closed by mdmould 1 month ago

mdmould commented 1 month ago

Because the lower and upper bounds for the inverter are arrays, they get picked up as trainable parameters when filtering with, e.g., is_inexact_array. I don't think this is ever desired behaviour, because training them would probably interact unexpectedly with the adaptive bounding in the bisection search. Usually this won't affect anything, because a flow that contains the inverter, e.g. BNAF, is typically trained without ever using the numerical inverse. I only noticed this because I was counting the number of parameters I expected in the model.
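To make the problem concrete, here is a minimal sketch (not flowjax's actual pytree; the names "weights", "lower", and "upper" are hypothetical stand-ins) showing how a generic is-inexact-array filter over a pytree counts the bounds as parameter leaves alongside the real weights:

```python
import jax.numpy as jnp
from jax import tree_util

# Toy pytree standing in for a flow that contains a bisection inverter.
# The bounds are float arrays, so a generic "inexact array" filter
# cannot distinguish them from genuine trainable weights.
params = {
    "weights": jnp.ones((3, 3)),
    "inverter": {"lower": jnp.array(-1.0), "upper": jnp.array(1.0)},
}

def is_inexact_array(x):
    return hasattr(x, "dtype") and jnp.issubdtype(x.dtype, jnp.inexact)

leaves = [l for l in tree_util.tree_leaves(params) if is_inexact_array(l)]
print(len(leaves))  # 3 leaves counted as "trainable", not just the weights
```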

I'm not sure if this is really an issue in practice. In any case, one can just wrap the inverter in a non_trainable or manually ignore the lower and upper "parameters" as necessary. But maybe there's a neater solution that keeps the bisection functions compatible with the scans and so on?
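As a sketch of the "manually ignore" workaround at the raw JAX level (the pytree and key names here are hypothetical; flowjax's non_trainable wrapper is the library-native route), applying jax.lax.stop_gradient to the bounds wherever they are used keeps them out of gradients while leaving the pytree structure, and hence scan compatibility, untouched:

```python
import jax
import jax.numpy as jnp

# Hypothetical model pytree containing bisection bounds.
params = {
    "weights": jnp.ones(3),
    "lower": jnp.array(-1.0),
    "upper": jnp.array(1.0),
}

def loss(p):
    # stop_gradient blocks gradient flow into the bounds.
    lower = jax.lax.stop_gradient(p["lower"])
    upper = jax.lax.stop_gradient(p["upper"])
    midpoint = (lower + upper) / 2.0  # e.g. one bisection step
    return jnp.sum((p["weights"] - midpoint) ** 2)

grads = jax.grad(loss)(params)
print(float(grads["lower"]), float(grads["upper"]))  # 0.0 0.0
```

The bounds still appear in the pytree (so jit/scan see a stable structure), but their gradients are exactly zero, so an optimizer never updates them.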

danielward27 commented 1 month ago

Good spot, thanks. You're absolutely right that they shouldn't be included; I'll get that fixed. As you said, it generally shouldn't matter (I believe JAX will error if you try to differentiate through the numerical method), but one case where I believe it does matter is when regularisation of parameters is used.
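To illustrate the regularisation case with a toy pytree (names are hypothetical, not flowjax internals): an L2 penalty summed over all float leaves silently includes the bounds, shifting the loss relative to penalising only the genuine weights:

```python
import jax.numpy as jnp
from jax import tree_util

# Toy pytree: a weight matrix plus bisection bounds (hypothetical names).
params = {
    "weights": jnp.ones((3, 3)),
    "inverter": {"lower": jnp.array(-1.0), "upper": jnp.array(1.0)},
}

# Penalty over every leaf picks up (-1)^2 + 1^2 from the bounds.
penalty_all = sum(jnp.sum(leaf**2) for leaf in tree_util.tree_leaves(params))
penalty_weights = jnp.sum(params["weights"] ** 2)
print(float(penalty_all), float(penalty_weights))  # 11.0 9.0
```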