Open Nick17t opened 1 year ago
Hello @Nick17t @samsja! I am Pranjal.
While surfing GSoC projects, I came across this today. Having multi-modal data structures compatible with JAX modules sounds really cool to me. I had a quick look through the DocArray codebase and found ArrayType and AnyDNN as the framework-agnostic types. I believe their uses in the codebase, such as in .embed() and docarray.math.distance, will need to be looked into for the JAX port. We will also need to decide on the use of either Flax or Haiku Modules for DNNs. Overall, the project seems very exciting to work on!
I would love to know more and contribute to the project.
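As a rough illustration of what porting a distance helper to JAX might involve, here is a minimal sketch of a pairwise cosine distance written with jnp. The function name cosine_distance, its signature, and the eps guard are assumptions made for this example and are not taken from docarray.math.distance.

```python
import jax.numpy as jnp


def cosine_distance(x, y, eps=1e-8):
    """Pairwise cosine distance between rows of x (n, d) and rows of y (m, d).

    Hypothetical helper for illustration; not DocArray's actual API.
    """
    # Normalize each row, guarding against division by zero for all-zero rows.
    x_unit = x / jnp.maximum(jnp.linalg.norm(x, axis=1, keepdims=True), eps)
    y_unit = y / jnp.maximum(jnp.linalg.norm(y, axis=1, keepdims=True), eps)
    # Cosine similarity is the dot product of unit vectors; distance is 1 - similarity.
    return 1.0 - x_unit @ y_unit.T


x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([[1.0, 0.0], [1.0, 1.0]])
dist = cosine_distance(x, y)  # result has shape (2, 2)
```

Because the function only uses jnp operations, it could also be wrapped in jax.jit if needed.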
@DevPranjal I added more info in the description of the issue. Be aware that this project is on DocArray v2
@Nick17t @samsja Based on the given information, here is what I understood:
DocArray is a library for representing, sending, and storing multi-modal data, with a focus on applications in ML and Neural Search. It currently supports PyTorch, Numpy, and TensorFlow as computational backends. We want to extend the backend support to include JAX. The project goal is to add JAX support as a computational backend in DocArray v2.
Here are the specific tasks involved:
Add a new backend to the computational backend while relying as much as possible on JAX Numpy (jnp) as a numpy-like interface for JAX.
Create a new Tensor object with the JAX backend, including variants for ImageTensor and other tensor types.
Ensure compatibility of DocumentArrayStack with JAX, with unit testing for each function in the computational backend.
Thoroughly test the implementation through the following:
Unit tests for each function in the computational backend, using predefined tensors and DocumentArrayStack. Integration tests to check the coherence of the entire implementation, with emphasis on training a small neural network using DocArray + JAX.
Upon successful completion of this project, DocArray v2 will support JAX as a computational backend alongside PyTorch, Numpy, and TensorFlow. The implementation will be thoroughly tested and documented.
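To make the tasks above a little more concrete, here is a minimal sketch of what a jnp-based backend and one of its unit tests could look like. The class name JaxCompBackend and the method names (stack, n_dim, to_numpy, minmax_normalize) are hypothetical choices for this example; the real DocArray v2 backend interface may differ.

```python
import jax.numpy as jnp
import numpy as np


class JaxCompBackend:
    """Hypothetical JAX backend; names are illustrative, not DocArray's API."""

    @staticmethod
    def stack(tensors, dim=0):
        # Stack a list of arrays along a new axis.
        return jnp.stack(tensors, axis=dim)

    @staticmethod
    def n_dim(array):
        # Number of dimensions of the array.
        return array.ndim

    @staticmethod
    def to_numpy(array):
        # Convert a JAX array to a NumPy array on the host.
        return np.asarray(array)

    @staticmethod
    def minmax_normalize(array, t_range=(0.0, 1.0), eps=1e-7):
        # Rescale values linearly into the interval t_range.
        low, high = t_range
        a_min, a_max = jnp.min(array), jnp.max(array)
        return (high - low) * (array - a_min) / (a_max - a_min + eps) + low


def test_stack_matches_numpy():
    # Compare the JAX backend against plain NumPy on predefined tensors.
    tensors = [jnp.ones((3,)), jnp.zeros((3,))]
    stacked = JaxCompBackend.stack(tensors)
    expected = np.stack([np.ones(3), np.zeros(3)])
    assert JaxCompBackend.n_dim(stacked) == 2
    np.testing.assert_allclose(JaxCompBackend.to_numpy(stacked), expected)


test_stack_matches_numpy()
```

Checking each function against its NumPy counterpart is one simple way to cover the backend with unit tests before moving on to integration tests.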
I would like to work on this project.
Hey @samsja, I would like to contribute to the project. Please guide me on how to get started. I am proficient in Python as well as machine learning using TensorFlow.
Hi @samsja @Nick17t, I tried doing it based on what I understood. Please let me know if I am on the right path.
Hi @DevPranjal @Arnav131003 @tehami02
I am delighted to hear that you are interested in contributing to the Jina AI community! 🎉
To get started, please take a moment to fill out our survey so that we can learn more about you and your skills.
Also, don't forget to mark your calendars for the GSoC x Jina AI webinar on March 23rd at 2 pm (CET). This is an excellent opportunity to learn more about the projects and ask any questions you have about the requirements and expectations.
Our mentors will provide an in-depth overview of the projects and answer any questions you may have. So please don't hesitate to ask any questions or seek clarification on any aspect of the project.
Is there anything specific you would like to learn from the webinar? Do you have any questions about the JAX support in DocArray v2 project that you would like to see clarified during the Q&A session? Let me know, and I'll be happy to help!
Looking forward to seeing you at the webinar, and thank you for your interest in the Jina AI community! 😊
Hi @Nick17t, this is a very interesting project, and I have worked on a similar kind of project where we had to create a new backend module for JAX. To make DocumentArrayStack compatible with JAX, we need to ensure that it works seamlessly with the JAX backend. This will involve testing the existing DocumentArrayStack code with the new JAX backend and resolving any compatibility issues that arise. I would love to work on this project 😁.
Project idea 6: JAX support in DocArray v2
Project Description
DocArray is a library for representing, sending, and storing multi-modal data, with a focus on applications in ML and Neural Search. It currently supports several deep learning frameworks, including PyTorch and TensorFlow. JAX is becoming increasingly popular for deep learning, so we want to integrate it into DocArray.
The project we propose is to add JAX as a backend for DocArray, alongside PyTorch and TensorFlow. The first part would involve rewriting and translating all of the computational backend functions of DocArray with the JAX framework. Then, we would battle-test the implementation against a real JAX use case, such as integrating DocArray with JAX support for model training and serving.
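For a sense of what a small JAX training use case might look like, here is a minimal sketch that fits a linear model with jax.grad and jax.jit. The plain dict named batch only stands in for a batch of documents with stacked tensor fields; the field names embedding and target, the learning rate, and the synthetic data are all assumptions for this example and do not reflect DocArray's API.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed value for this toy example


def predict(params, x):
    # Linear model: one weight per embedding dimension plus a bias.
    return x @ params["w"] + params["b"]


def loss_fn(params, batch):
    # Mean squared error between predictions and targets.
    pred = predict(params, batch["embedding"])
    return jnp.mean((pred - batch["target"]) ** 2)


@jax.jit
def train_step(params, batch):
    # One step of plain gradient descent over the parameter pytree.
    grads = jax.grad(loss_fn)(params, batch)
    return jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )


# Synthetic data standing in for stacked document fields.
key = jax.random.PRNGKey(0)
embeddings = jax.random.normal(key, (64, 3))
targets = embeddings @ jnp.array([1.0, -2.0, 0.5]) + 0.3
batch = {"embedding": embeddings, "target": targets}

params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
initial_loss = float(loss_fn(params, batch))
for _ in range(200):
    params = train_step(params, batch)
final_loss = float(loss_fn(params, batch))
```

An integration test could follow the same pattern and assert that the loss decreases after training.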
Expected outcomes
Desired skills
More detailed :
This project targets DocArray, specifically the current rewrite, DocArray v2, which is a new codebase.
We currently support three computational frameworks in DocArray v2: PyTorch, NumPy, and TensorFlow. We would like to add JAX support.
More info about JAX can be found here but in short, it is a deep learning framework supported by Google that is getting a lot of traction, especially among researchers.
Concretely what is expected in this project:
Add a new backend to our Computational Backend while relying as much as possible on jnp (JAX NumPy), which is a NumPy-like interface for JAX. A similar approach can be found for the TensorFlow backend: https://github.com/docarray/docarray/blob/feat-rewrite-v2/docarray/computation/tensorflow_backend.py
Create a new Tensor object with the JAX backend. Example: ImageTensor will need a JAX variant (as will all of the other tensor types).
Make DocumentArrayStack compatible with JAX. Hopefully, this should be straightforward since it is meant to be agnostic to the computational backend, but since we noticed some problems with the TensorFlow backend, we can expect some friction here.
Battle test the whole computational backend: