jina-ai / GSoC

Google Summer of Code

JAX support in DocArray v2 #21

Open Nick17t opened 1 year ago

Nick17t commented 1 year ago

Project idea 6: JAX support in DocArray v2

| Info | Details |
|---|---|
| Skills needed | Python, deep learning, JAX |
| Project size | 175 hours |
| Difficulty level | Hard |
| Mentors | @Sami Jaghouar |

Project Description

Expected outcomes

Desired skills

More details:

This project targets DocArray, specifically the current rewrite, DocArray v2, which is a new codebase.

We currently support three computational frameworks in DocArray v2: PyTorch, NumPy, and TensorFlow. We would like to add JAX support.

More info about JAX can be found here, but in short, it is a deep learning framework backed by Google that is gaining a lot of traction, especially among researchers.

Concretely, what is expected in this project:

DevPranjal commented 1 year ago

Hello @Nick17t @samsja! I am Pranjal.

While browsing GSoC projects, I came across this one today. Having multi-modal data structures compatible with JAX modules sounds really cool to me. I took a quick look through the DocArray codebase and found ArrayType and AnyDNN as the framework-agnostic types. I believe their uses in the codebase, such as in .embed() and docarray.math.distance, will need to be looked into for the JAX port. We will also need to decide whether to use Flax or Haiku modules for DNNs. Overall, the project seems very exciting to work on!
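As a rough illustration of what porting a distance function like those in docarray.math.distance might involve, here is a minimal sketch of a pairwise cosine distance written only with jax.numpy. The function name and signature are hypothetical and not DocArray's actual API; the point is that jnp mirrors NumPy closely enough that most of this kind of math ports almost line for line.

```python
import jax.numpy as jnp


def cosine_distance(x: jnp.ndarray, y: jnp.ndarray, eps: float = 1e-7) -> jnp.ndarray:
    """Hypothetical helper: pairwise cosine distance between rows of x and y.

    x has shape (n, d) and y has shape (m, d); 1-D inputs are treated as a
    single row. Returns an (n, m) array where entry [i, j] is 1 - cos(x[i], y[j]).
    """
    # Promote 1-D inputs to a single row so the matmul always sees 2-D arrays.
    x = jnp.atleast_2d(x)
    y = jnp.atleast_2d(y)
    # L2-normalize each row; eps guards against division by zero for all-zero rows.
    x_norm = x / (jnp.linalg.norm(x, axis=-1, keepdims=True) + eps)
    y_norm = y / (jnp.linalg.norm(y, axis=-1, keepdims=True) + eps)
    # Cosine similarity is the dot product of unit vectors; distance is 1 minus it.
    return 1.0 - x_norm @ y_norm.T
```

With queries of shape (n, d) and an index of shape (m, d), this yields an (n, m) distance matrix that could then feed a top-k selection. Because it only uses jnp operations, it can also be wrapped in jax.jit unchanged.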

I would love to know more and contribute to the project.

samsja commented 1 year ago

@DevPranjal I added more info in the description of the issue. Be aware that this project targets DocArray v2.

tehami02 commented 1 year ago

@Nick17t @samsja Based on the given information, here is what I understood:

Project Description:

DocArray is a library for representing, sending, and storing multi-modal data, with a focus on applications in ML and neural search. It currently supports PyTorch, NumPy, and TensorFlow as computational backends, and the goal of this project is to extend that support by adding JAX as a computational backend in DocArray v2.

Here are the specific tasks involved:

  1. Add a new JAX backend to the computational backends, relying as much as possible on JAX NumPy (jnp) as a NumPy-like interface for JAX.

  2. Create a new Tensor object with the JAX backend, including variants for ImageTensor and other tensor types.

  3. Ensure compatibility of DocumentArrayStack with JAX, with unit testing for each function in the computational backend.

  4. Thoroughly test the implementation through the following:

     - Unit tests for each function in the computational backend, using predefined tensors and DocumentArrayStack.
     - Integration tests to check the coherence of the entire implementation, with emphasis on training a small neural network using DocArray + JAX.
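To make tasks 1 and 4 more concrete, here is a minimal sketch of what a jnp-based backend might look like. The class name JaxCompBackend and its method names are hypothetical placeholders chosen for illustration; they are not taken from DocArray's computational backend interface, which a real implementation would need to follow.

```python
from typing import List, Tuple

import jax
import jax.numpy as jnp
import numpy as np


class JaxCompBackend:
    """Hypothetical jnp-based computational backend (illustrative names only)."""

    @staticmethod
    def stack(tensors: List[jnp.ndarray], dim: int = 0) -> jnp.ndarray:
        # Join same-shaped tensors along a new axis, e.g. a batch axis.
        return jnp.stack(tensors, axis=dim)

    @staticmethod
    def n_dim(array: jnp.ndarray) -> int:
        return array.ndim

    @staticmethod
    def shape(array: jnp.ndarray) -> Tuple[int, ...]:
        return tuple(array.shape)

    @staticmethod
    def reshape(array: jnp.ndarray, shape: Tuple[int, ...]) -> jnp.ndarray:
        return jnp.reshape(array, shape)

    @staticmethod
    def to_numpy(array: jnp.ndarray) -> np.ndarray:
        # Copy the (possibly device-resident) array to a host NumPy array.
        return np.asarray(array)

    @staticmethod
    def detach(array: jnp.ndarray) -> jnp.ndarray:
        # JAX arrays have no attached autograd graph; stop_gradient is the
        # closest analogue to torch's .detach() inside differentiated code.
        return jax.lax.stop_gradient(array)

    @staticmethod
    def minmax_normalize(
        array: jnp.ndarray, t_range: Tuple[float, float] = (0.0, 1.0)
    ) -> jnp.ndarray:
        # Linearly rescale values from [min, max] to [t_range[0], t_range[1]].
        a, b = t_range
        lo, hi = jnp.min(array), jnp.max(array)
        # Avoid dividing by zero when the array is constant (hi == lo).
        span = jnp.where(hi > lo, hi - lo, 1.0)
        return a + (array - lo) * (b - a) / span

    @staticmethod
    def top_k(
        values: jnp.ndarray, k: int, descending: bool = False
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # jax.lax.top_k returns the k largest entries along the last axis,
        # so negate the input to obtain the k smallest (e.g. nearest distances).
        scores = values if descending else -values
        top_vals, idx = jax.lax.top_k(scores, k)
        return (top_vals if descending else -top_vals), idx
```

The unit tests from task 4 could then check each method against predefined tensors, e.g. that stacking two shape-(3,) arrays gives shape (2, 3). Mapping detach to jax.lax.stop_gradient is a judgment call: it only matters inside functions being differentiated, since JAX arrays do not carry a graph the way PyTorch tensors do.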

Expected outcomes:

Upon successful completion of this project, DocArray v2 will support JAX as a computational backend alongside PyTorch, NumPy, and TensorFlow. The implementation will be thoroughly tested and documented.

I would like to work on this project.

Arnav131003 commented 1 year ago

Hey @samsja, I would like to contribute to the project. Please guide me on how to get started. I am proficient in Python as well as machine learning using TensorFlow.

Arnav131003 commented 1 year ago

Hi @samsja @Nick17t, I tried implementing it based on what I understood. Please let me know if I am on the right path.

[Screenshot 2023-03-13 at 5:08:16 PM]

Nick17t commented 1 year ago

Hi @DevPranjal @Arnav131003 @tehami02

I am delighted to hear that you are interested in contributing to the Jina AI community! 🎉

To get started, please take a moment to fill out our survey so that we can learn more about you and your skills.

Also, don't forget to mark your calendars for the GSoC x Jina AI webinar on March 23rd at 2 pm (CET). This is an excellent opportunity to learn more about the projects and ask any questions you have about the requirements and expectations.

Our mentors will provide an in-depth overview of the projects and answer any questions you may have. So please don't hesitate to ask any questions or seek clarification on any aspect of the project.

Is there anything specific you would like to learn from the webinar? Do you have any questions about the JAX support in DocArray v2 project that you would like to see clarified during the Q&A session? Let me know, and I'll be happy to help!

Looking forward to seeing you at the webinar, and thank you for your interest in the Jina AI community! 😊

Lancelot03 commented 1 year ago

Hi @Nick17t, this is a very interesting project, and I have worked on a similar project where we had to create a new backend module for JAX. To make DocumentArrayStack compatible with JAX, we need to ensure that it works seamlessly with the JAX backend. This will involve testing the existing DocumentArrayStack code with the new JAX backend and resolving any compatibility issues that arise. I would love to work on this project 😁.
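One compatibility issue that is likely to surface when testing DocumentArrayStack against JAX is that JAX arrays are immutable: if any code path writes into a stacked tensor in place, as is natural with NumPy or PyTorch, it has to switch to JAX's functional .at[...].set(...) update. The minimal sketch below illustrates this with hypothetical helpers stack_column and set_row (not DocArray functions) and checks that a jitted batched operation behaves as expected on the stacked data.

```python
from typing import List

import jax
import jax.numpy as jnp


def stack_column(tensors: List[jnp.ndarray]) -> jnp.ndarray:
    # Hypothetical helper: gather one tensor per document into a (num_docs, ...) batch.
    return jnp.stack(tensors, axis=0)


def set_row(column: jnp.ndarray, i: int, value: jnp.ndarray) -> jnp.ndarray:
    # JAX arrays are immutable, so `column[i] = value` raises a TypeError.
    # This functional update returns a new array; the caller must keep the result.
    return column.at[i].set(value)


@jax.jit
def l2_normalize(batch: jnp.ndarray) -> jnp.ndarray:
    # A batched op whose rows should match the per-document results.
    return batch / jnp.linalg.norm(batch, axis=-1, keepdims=True)


docs = [jnp.array([3.0, 4.0]), jnp.array([0.0, 2.0])]
column = stack_column(docs)

try:
    column[0] = jnp.zeros(2)  # the NumPy/PyTorch idiom
except TypeError:
    print("in-place assignment is not supported on JAX arrays")

column = set_row(column, 0, jnp.array([6.0, 8.0]))
normalized = l2_normalize(column)
```

Since set_row returns a new array, a JAX-backed stack would need to store the result back into its column rather than rely on mutation; the original per-document arrays in docs are left untouched because jnp.stack copies them.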