google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

hk.BatchNorm with jax.vmap #786

Open · reemabdelrazek30 opened this issue 4 months ago

reemabdelrazek30 commented 4 months ago

Is there a workaround that would let me use jax.vmap with hk.BatchNorm? Should I use hk.vmap instead, or should I write a custom BatchNorm?