Open reemabdelrazek30 opened 4 months ago
Is there a workaround that would let me use jax.vmap with hk.BatchNorm? Should I use hk.vmap instead, or should I write a custom BatchNorm?