poets-ai / elegy

A High Level API for Deep Learning in JAX
https://poets-ai.github.io/elegy/
MIT License
468 stars, 32 forks

[Feature Request] Switch out imports for numpy, jax, tf.numpy, PyTorch, cupy, etc. #161

Open SamuelMarks opened 3 years ago

SamuelMarks commented 3 years ago

Is your feature request related to a problem? Please describe.

There are too many alternatives to elegy. Let's make more. If elegy is to become the Keras of ML, let's remove all the dependencies.

Describe the solution you'd like

Remove all non-builtin imports. Specifically, centralise them in one file, say elegy.engine. This would be simple to do, as there are only these 55 occurrences, nearly all in examples:

```console
(master#082fdc) $ rg -Ftpy 'import '|rg -v 'typing|elegy|debugpy|types|os|datatime|pathlib|shutil|yaml|datetime'
scripts/update_docs.py:from dataclasses import dataclass
scripts/update_docs.py:import jax
scripts/update_docs.py:import jinja2
examples/flax_mnist_vae_test_step.py:from tensorboardX.writer import SummaryWriter
examples/flax_mnist_vae_test_step.py:import dataget
examples/flax_mnist_vae_test_step.py:from flax import linen as nn
examples/flax_mnist_vae_test_step.py:import jax
examples/flax_mnist_vae_test_step.py:import jax.numpy as jnp
examples/flax_mnist_vae_test_step.py:import matplotlib.pyplot as plt
examples/flax_mnist_vae_test_step.py:import numpy as np
examples/flax_mnist_vae_test_step.py:import typer
examples/flax_mnist_vae_test_step.py:import optax
examples/imagenet/resnet_imagenet.py:from absl import flags, app
examples/imagenet/resnet_imagenet.py:import jax, jax.numpy as jnp
examples/imagenet/resnet_imagenet.py:import optax
examples/imagenet/resnet_imagenet.py:import tensorflow_datasets as tfds
examples/imagenet/resnet_imagenet.py:import input_pipeline
examples/imagenet/input_pipeline.py:import jax
examples/imagenet/input_pipeline.py:import tensorflow as tf
examples/imagenet/input_pipeline.py:import tensorflow_datasets as tfds
examples/jax_linear_classifier_train_step.py:import dataget
examples/jax_linear_classifier_train_step.py:import jax
examples/jax_linear_classifier_train_step.py:import jax.numpy as jnp
examples/jax_linear_classifier_train_step.py:import numpy as np
examples/jax_linear_classifier_train_step.py:import optax
examples/jax_linear_classifier_train_step.py:import typer
examples/jax_linear_classifier_test_step.py:import dataget
examples/jax_linear_classifier_test_step.py:import jax
examples/jax_linear_classifier_test_step.py:import jax.numpy as jnp
examples/jax_linear_classifier_test_step.py:import numpy as np
examples/jax_linear_classifier_test_step.py:import optax
examples/jax_linear_classifier_test_step.py:import typer
examples/jax_linear_classifier_functional.py:import dataget
examples/jax_linear_classifier_functional.py:from flax import linen
examples/jax_linear_classifier_functional.py:import jax
examples/jax_linear_classifier_functional.py:import jax.numpy as jnp
examples/jax_linear_classifier_functional.py:import numpy as np
examples/jax_linear_classifier_functional.py:import optax
examples/jax_linear_classifier_functional.py:import typer
examples/flax_mnist_vae.py:from tensorboardX.writer import SummaryWriter
examples/flax_mnist_vae.py:import dataget
examples/flax_mnist_vae.py:from flax import linen as nn
examples/flax_mnist_vae.py:import jax
examples/flax_mnist_vae.py:import jax.numpy as jnp
examples/flax_mnist_vae.py:import matplotlib.pyplot as plt
examples/flax_mnist_vae.py:import numpy as np
examples/flax_mnist_vae.py:import typer
examples/flax_mnist_vae.py:import optax
examples/flax_linear_classifier.py:import dataget
examples/flax_linear_classifier.py:from flax import linen
examples/flax_linear_classifier.py:import jax
examples/flax_linear_classifier.py:import jax.numpy as jnp
examples/flax_linear_classifier.py:import numpy as np
examples/flax_linear_classifier.py:import optax
examples/flax_linear_classifier.py:import typer
```

So the solution would be to select the backend with a function, an environment variable, or both:

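```python
import os

# Pick the array backend once, at import time (default: JAX).
# Lower-casing makes "JAX", "jax" and "Jax" equivalent.
engine = os.environ.setdefault("ELEGY_ENGINE", "jax").lower()

if engine == "jax":
    import jax.numpy as np
elif engine == "cupy":
    import cupy as np
elif engine in frozenset(("tensorflow", "tf")):
    import tensorflow.experimental.numpy as np
elif engine in frozenset(("numpy", "np")):
    import numpy as np
else:
    # Fail fast instead of leaving `np` undefined.
    raise ValueError(f"Unsupported ELEGY_ENGINE: {engine!r}")
```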

Obviously this is an incomplete solution: there are a number of inconsistencies between the various NumPy implementations, and the API elegy uses goes beyond mere NumPy. But moving in that direction is a good idea.
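One concrete inconsistency: JAX arrays are immutable, so `x[i] = v` raises an error and JAX offers the functional `x.at[i].set(v)` instead, whereas NumPy and CuPy arrays are updated in place. A shim papering over this might look like the sketch below. `index_update` is a hypothetical helper, and detecting a JAX array by the presence of an `.at` attribute is an assumption (duck typing), not an official check.

```python
from typing import Any


def index_update(x: Any, idx: Any, value: Any) -> Any:
    """Hypothetical helper: return a copy of x with x[idx] = value.

    The input is never modified, matching JAX's functional semantics.
    """
    if hasattr(x, "at"):
        # Assumed JAX-style immutable array: use the functional update.
        return x.at[idx].set(value)
    # TODO: tf.experimental.numpy arrays would need their own branch.
    # Mutable arrays (NumPy, CuPy): copy first so the caller's array is untouched.
    y = x.copy()
    y[idx] = value
    return y
```

Because the mutable path copies before assigning, both branches leave the original array unchanged, so calling code behaves the same on every backend.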

Describe alternatives you've considered

Write abstract classes and implement them in elegy-numpy, elegy-cupy, elegy-tensorflow, and elegy-jax.

cgarciae commented 3 years ago

Hey @SamuelMarks!

Let me first clarify our current goals / situation:

As for creating multiple backends, my personal opinion is that Keras ultimately showed that supporting multiple backends is more of a weakness than a strength: it limits the user compared to having a single first-class backend. This is why Keras is now part of TensorFlow. Thanks to this change, the need for Lambda layers disappeared, vastly improving the developer experience, since you can now freely use any TensorFlow op directly.

SamuelMarks commented 3 years ago

Hey @cgarciae,

Thanks for your speedy response.

I'd just like to note that the future of all these frameworks is uncertain. optax or jax.experimental.optimizers? Flax or Trax? JAX or TensorFlow?

It is conceivable that a number of these codebases will merge, be deprecated, or become completely unmaintained. Since elegy is not backed by Google, Facebook, or Microsoft, the chance that it will be abandoned too is higher.

What I'm proposing is a way out. I don't expect this to be easy, and certain nice primitives that JAX exposes aren't available in other frameworks. But having a bunch of TODOs around the place isn't such a bad idea; we can always port the JAX primitives at a later stage, probably exposed through their own meta polyfill library.
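As a sketch of what such a polyfill might contain, assuming a hypothetical `get_vmap` helper: use `jax.vmap` when JAX is the engine, and elsewhere fall back to a naive Python loop over the leading axis. The fallback only covers the default `in_axes=0`, single-argument case and offers none of `vmap`'s vectorized performance; it just keeps code running while a proper port remains a TODO.

```python
from typing import Any, Callable, Sequence


def vmap_fallback(
    fn: Callable[[Any], Any],
    stack: Callable[[Sequence[Any]], Any] = list,
) -> Callable[[Any], Any]:
    """Naive stand-in for jax.vmap with in_axes=0.

    Calls fn once per element of the leading axis in plain Python and joins
    the results with `stack` (pass numpy.stack to get an array back).
    """

    def batched(xs: Any) -> Any:
        return stack([fn(x) for x in xs])

    return batched


def get_vmap(engine: str) -> Callable[..., Any]:
    """Hypothetical: native primitive where it exists, polyfill elsewhere."""
    if engine == "jax":
        import jax  # imported only when JAX is the active engine

        return jax.vmap
    # TODO: faster backend-specific versions for CuPy/TensorFlow could come later.
    return vmap_fallback
```

Callers write `get_vmap(engine)(fn)(batch)` regardless of backend; only the speed differs.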

Yes, Keras failed. And maybe elegy, or whoever implements this concept, will fail too. But at some stage, a framework will succeed in coordinating almost everyone into using one consistent API. What I'm proposing is to give it a shot; IMHO it's the best chance for elegy to be around in the medium to long term.