jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.26k stars, 2.77k forks

[GSoC 2021 Proposal] Few shot algorithms in JAX #5802

Closed by INF800 3 years ago

INF800 commented 3 years ago

Hi!

I want to implement an end-to-end framework for few-shot and one-shot learning algorithms using JAX. This will be a long-term project, but for now I am planning to implement the following as part of GSoC 2021:

Any pointers on how to proceed would be immensely helpful.

Cheers! Rakesh

mattjj commented 3 years ago

Hey @rakesh4real, that sounds really fun!

I suggest checking out the Flax or Haiku libraries, as you'll likely want to use more than just core JAX to build and train these kinds of models.
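To show what core JAX covers before a model library is added, here is a minimal sketch of a MAML-style (model-agnostic meta-learning) few-shot update. MAML is an assumption here, since the thread names no specific algorithm. The sketch uses only `jax.grad`, `jax.vmap`, and `jax.jit`. The toy linear model, the synthetic regression tasks, the step sizes, and every function name (`predict`, `inner_update`, `task_loss`, `maml_loss`, `meta_step`) are hypothetical and chosen for illustration only.

```python
import jax
import jax.numpy as jnp

# Hypothetical step sizes for this toy example.
INNER_LR = 0.01  # task-adaptation step on the support set
OUTER_LR = 0.01  # meta-update step across tasks


def predict(params, x):
    # Toy linear model y = x @ w + b, standing in for a real network.
    w, b = params
    return x @ w + b


def mse_loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)


def inner_update(params, x_support, y_support):
    # MAML inner loop: one plain gradient step on the task's support set.
    grads = jax.grad(mse_loss)(params, x_support, y_support)
    return jax.tree_util.tree_map(lambda p, g: p - INNER_LR * g, params, grads)


def task_loss(params, x_support, y_support, x_query, y_query):
    # Adapt on the support set, then evaluate the adapted params on the query set.
    adapted = inner_update(params, x_support, y_support)
    return mse_loss(adapted, x_query, y_query)


def maml_loss(params, batch):
    # vmap over the leading task axis; params are shared across tasks (in_axes=None).
    per_task = jax.vmap(task_loss, in_axes=(None, 0, 0, 0, 0))(params, *batch)
    return jnp.mean(per_task)


@jax.jit
def meta_step(params, batch):
    # Differentiating through inner_update gives the second-order MAML gradient.
    loss, grads = jax.value_and_grad(maml_loss)(params, batch)
    new_params = jax.tree_util.tree_map(lambda p, g: p - OUTER_LR * g, params, grads)
    return new_params, loss


# Usage on synthetic linear-regression tasks (4 tasks, 5-shot support, 10 queries).
n_tasks, n_shot, n_query, dim = 4, 5, 10, 3
k_w, k_s, k_q = jax.random.split(jax.random.PRNGKey(0), 3)
true_w = jax.random.normal(k_w, (n_tasks, dim))  # one weight vector per task
x_s = jax.random.normal(k_s, (n_tasks, n_shot, dim))
x_q = jax.random.normal(k_q, (n_tasks, n_query, dim))
y_s = jnp.einsum("tnd,td->tn", x_s, true_w)
y_q = jnp.einsum("tnd,td->tn", x_q, true_w)
batch = (x_s, y_s, x_q, y_q)

params = (jnp.zeros(dim), jnp.zeros(()))
initial_loss = maml_loss(params, batch)
for _ in range(100):
    params, _ = meta_step(params, batch)
final_loss = maml_loss(params, batch)
```

The design leans on composing transformations: `jax.vmap` batches the per-task adaptation, `jax.grad` is nested inside `jax.value_and_grad`, and `jax.jit` compiles the whole meta-step. In a Flax or Haiku project, the hand-written `predict` would typically be replaced by a module's `apply` function, and the manual `tree_map` updates by an optimizer library.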

Hope your project goes well!