d2l-ai / d2l-en

Interactive deep learning book with multi-framework code, math, and discussions. Adopted at 500 universities from 70 countries including Stanford, MIT, Harvard, and Cambridge.
https://D2L.ai

Contributing to the jax version. #1972

Open · ashutosh-dwivedi-e3502 opened 2 years ago

ashutosh-dwivedi-e3502 commented 2 years ago

I want to contribute to the effort for the JAX version. I see that there's already a branch called jax with ~20 commits, but it is quite behind the main branch. Can you give me some more details on the current JAX effort and how I can contribute?

astonzhang commented 2 years ago

Thanks! See https://github.com/d2l-ai/d2l-en/issues/1825

Roy-Kid commented 2 years ago

I also want to contribute to the JAX version, and I pulled the jax branch. But I only see a few changes in this branch (see attached image). Did I pull the right branch?

ghost commented 2 years ago

Since some chapters of the book are out of sync with the master branch, which version should be implemented in JAX?

For example, the code in Linear Regression from Scratch:

def synthetic_data(w, b, num_examples):  #@save
    """Generate y = Xw + b + noise."""
    X = np.random.normal(0, 1, (num_examples, len(w)))
    y = np.dot(X, w) + b
    y += np.random.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))

true_w = np.array([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

Is different from the corresponding code in linear-regression-scratch.md:

%%tab all
class LinearRegressionScratch(d2l.Module):  #@save
    def __init__(self, num_inputs, lr, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        if tab.selected('mxnet'):
            self.w = d2l.normal(0, sigma, (num_inputs, 1))
            self.b = d2l.zeros(1)
            self.w.attach_grad()
            self.b.attach_grad()
        if tab.selected('pytorch'):
            self.w = d2l.normal(0, sigma, (num_inputs, 1), requires_grad=True)
            self.b = d2l.zeros(1, requires_grad=True)
        if tab.selected('tensorflow'):
            w = tf.random.normal((num_inputs, 1), mean=0, stddev=0.01)
            b = tf.zeros(1)
            self.w = tf.Variable(w, trainable=True)
            self.b = tf.Variable(b, trainable=True)

In this case, should I manually modify the one from the master branch, like so:

        if tab.selected('tensorflow'):
            w = tf.random.normal((num_inputs, 1), mean=0, stddev=0.01)
            b = tf.zeros(1)
            self.w = tf.Variable(w, trainable=True)
            self.b = tf.Variable(b, trainable=True)
+       if tab.selected('jax'):
+           key = random.PRNGKey(42)
+           self.w = sigma * random.normal(key, (num_inputs, 1))
+           self.b = jnp.zeros(1)

Or perhaps it would be better to make the d2l module work with JAX first.

What do you suggest? @astonzhang
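
For context, here is a minimal, self-contained sketch of what such a JAX branch could look like. This is illustrative only: the class name LinearRegressionScratchJax is hypothetical, the d2l.Module plumbing is omitted, and the seed parameter is an assumption.

import jax.numpy as jnp
from jax import random

class LinearRegressionScratchJax:  # hypothetical stand-in for the d2l.Module subclass
    def __init__(self, num_inputs, lr, sigma=0.01, seed=0):
        self.lr = lr
        # JAX has no global RNG state; parameters are drawn from an explicit key.
        key = random.PRNGKey(seed)
        # Scale a standard normal by sigma rather than hard-coding 0.01.
        self.w = sigma * random.normal(key, (num_inputs, 1))
        self.b = jnp.zeros(1)

    def forward(self, X):
        return jnp.matmul(X, self.w) + self.b

Note that JAX computes gradients functionally (e.g. via jax.grad over a loss that takes the parameters as explicit arguments), so there is no analogue of requires_grad or attach_grad; that difference is one reason adapting d2l.Module first may be the cleaner route.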

atgctg commented 2 years ago

Hey @AnirudhDagar, could you please merge master into the jax branch?

I'm working on a JAX version for the v1 release of the book and have a few chapters ready for review. I just don't want to open any PRs until the jax branch is up to date with v1.

Or should I give it a go?

AnirudhDagar commented 2 years ago

Hi @atgctg, thanks for your interest in the JAX port. I've synced the branch. I'm almost done with chapter 3; I didn't know that you were working on it as well, but feel free to raise a PR for the chapter. Going forward, to avoid duplication and two people working on the same thing, let's coordinate through a tracker.

atgctg commented 2 years ago

Thanks! I'll open a PR then.

It would be great to standardize the API first, so other chapters can build on that. I would love to see your approach as well.
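
For illustration, here is one shape the standardized API could take if the port were built on flax.linen (a sketch under that assumption, not an agreed design):

import jax.numpy as jnp
from flax import linen as nn
from jax import random

class LinearRegression(nn.Module):  # hypothetical example module
    @nn.compact
    def __call__(self, X):
        # Output dimension 1; the input dimension is inferred from X.
        return nn.Dense(1)(X)

# Parameters live outside the module and are threaded through explicitly.
model = LinearRegression()
params = model.init(random.PRNGKey(0), jnp.ones((4, 2)))
preds = model.apply(params, jnp.ones((4, 2)))

The appeal of a functional setup like this is that every chapter could share the same init/apply pattern regardless of the model.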

AnirudhDagar commented 2 years ago

Would you be so kind as to wait until this evening before sending a PR? I'd like to fix some CI issues first, which might affect JAX development. I'll let you know once that is fixed. Thanks! :)

atgctg commented 2 years ago

Of course!