mjo22 / cryojax

Cryo electron microscopy image simulation and analysis built on JAX.
https://mjo22.github.io/cryojax/
GNU Lesser General Public License v2.1

Change recommendation for how to instantiate pytrees with vmap #202

Closed by mjo22 6 months ago.

mjo22 commented 6 months ago

`equinox.internal.if_mapped` changes everything! Passing it as `out_axes` in a call to `eqx.filter_vmap` means that pytree leaves which are not mapped over are not broadcast along the batch axis. See `simulate-micrograph.ipynb` for an example and a description of this.