google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

PJit example #1763

Open mattiasmar opened 2 years ago

mattiasmar commented 2 years ago

Description of the model to be implemented

ResNet18

Dataset the model could be trained on

tensorflow_datasets mock_data, e.g.

import tensorflow_datasets as tfds

with tfds.testing.mock_data(num_examples=128):
    ds = tfds.load('imagenette', split='train')

Specific points to consider

Reference implementations in other frameworks

https://github.com/google/flax/blob/main/examples/imagenet/imagenet.ipynb

marcvanzee commented 2 years ago

Hi @mattiasmar, we are considering adding a WMT example with pjit instead, since many people working on large language models are interested in it (and language models are usually much bigger than vision models). Would that work for you as well?

mattiasmar commented 2 years ago

Yes, WMT would be a good example. Just one request: could you apply pjit as high up in the code as possible? As a user, I want to pjit as large a part of my model/program as possible. Ideally, I would use pjit only once in my program.
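
For illustration, here is a minimal sketch of what "pjit only once" might look like. It is not the Flax WMT example: the toy linear model, `loss_fn`, `train_step`, `LEARNING_RATE`, and the mesh axis name `"data"` are all hypothetical. It also assumes a recent JAX release, where the pjit functionality is exposed through `jax.jit` with `in_shardings` and `out_shardings`. The whole training step (forward pass, gradients, and parameter update) is wrapped in a single partitioned call.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

LEARNING_RATE = 0.1  # hypothetical value, chosen only for this toy problem


def loss_fn(params, batch):
    # Hypothetical toy model: one linear layer standing in for a real network.
    preds = batch["x"] @ params["w"] + params["b"]
    return jnp.mean((preds - batch["y"]) ** 2)


def train_step(params, batch):
    # Forward pass, backward pass, and SGD update all live in one function,
    # so a single partitioned call covers the entire step.
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# One-dimensional device mesh; works with a single CPU device as well.
mesh = Mesh(np.array(jax.devices()), ("data",))
replicated = NamedSharding(mesh, P())          # full copy on every device
data_sharded = NamedSharding(mesh, P("data"))  # split along the batch axis

# The only place where partitioning is requested.
p_train_step = jax.jit(
    train_step,
    in_shardings=(replicated, data_sharded),
    out_shardings=(replicated, replicated),
)

# Synthetic data; the batch size is a multiple of the device count.
n = 8 * jax.device_count()
x = jax.random.normal(jax.random.PRNGKey(0), (n, 4))
y = x @ jnp.array([1.0, -2.0, 0.5, 3.0])
batch = jax.device_put({"x": x, "y": y}, data_sharded)
params = jax.device_put({"w": jnp.zeros(4), "b": jnp.zeros(())}, replicated)

losses = []
for _ in range(20):
    params, loss = p_train_step(params, batch)
    losses.append(float(loss))
```

In this sketch the parameters are replicated and only the batch is sharded, which amounts to plain data parallelism. Sharding large model weights would need per-parameter partition specs, which this toy example does not attempt.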