juliuskunze / jaxnet

Concise deep learning for JAX
Apache License 2.0

Allow vmap(parametrized(fun)) #16

Closed: juliuskunze closed this issue 4 years ago.

juliuskunze commented 4 years ago

Add a batching rule to the parametrized primitive.
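For context, `jax.vmap` over a custom primitive raises an error unless that primitive has a batching rule registered. The following is a minimal sketch of what such a rule can look like. It uses a hypothetical toy primitive `scale_p` (y = w * x), not jaxnet's actual `parametrized` primitive, whose internals are not shown in this issue. It also assumes the registry `jax.interpreters.batching.primitive_batchers`, which JAX's "How JAX primitives work" guide uses but which belongs to JAX internals and may change between versions.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import batching

try:  # newer JAX versions expose Primitive under jax.extend.core
    from jax.extend.core import Primitive
except ImportError:  # fallback for older JAX versions
    from jax.core import Primitive

# Hypothetical toy primitive: y = w * x, where w is a scalar "parameter".
# (Illustration only; not jaxnet's `parametrized` primitive.)
scale_p = Primitive("scale")


def scale(w, x):
    return scale_p.bind(w, x)


@scale_p.def_impl
def _scale_impl(w, x):
    # Eager evaluation: ordinary elementwise multiplication.
    return w * x


@scale_p.def_abstract_eval
def _scale_abstract_eval(w, x):
    # The output has the same shape and dtype as x.
    return x


def _scale_batch(batched_args, batch_dims):
    """Batching rule: returns (batched output, output batch axis).

    JAX passes None as the batch dim of an operand that is not batched.
    """
    w, x = batched_args
    w_bdim, x_bdim = batch_dims
    if w_bdim is None:
        # Only x is batched: scaling by an unbatched scalar commutes with
        # batching, so apply the primitive once and keep x's batch axis.
        return scale(w, x), x_bdim
    # w is batched (one scalar per example): put the batch axis first.
    w = jnp.moveaxis(w, w_bdim, 0)
    if x_bdim is None:
        x = jnp.broadcast_to(x, (w.shape[0],) + x.shape)
    else:
        x = jnp.moveaxis(x, x_bdim, 0)
    # Reshape w to (B, 1, ..., 1) so it broadcasts over x's remaining axes.
    w = w.reshape((w.shape[0],) + (1,) * (x.ndim - 1))
    return w * x, 0


# Assumed registry (JAX internals): without this line, vmap(scale) fails.
batching.primitive_batchers[scale_p] = _scale_batch

xs = jnp.arange(6.0).reshape(3, 2)
ws = jnp.array([1.0, 2.0, 3.0])
out_shared_w = jax.vmap(scale, in_axes=(None, 0))(jnp.float32(2.0), xs)
out_per_example_w = jax.vmap(scale)(ws, xs)
out_shared_x = jax.vmap(scale, in_axes=(0, None))(ws, jnp.array([1.0, 10.0]))
```

The rule covers each combination of batched and unbatched operands and always reports where the batch axis of the output ends up. A batching rule for a primitive that wraps a whole parametrized function would need to do the analogous bookkeeping for every input and parameter.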

juliuskunze commented 4 years ago

This use case is now covered by Batched (https://github.com/JuliusKunze/jaxnet/issues/20).