ashutosh-dwivedi-e3502 opened this issue 2 years ago.
Thanks! See https://github.com/d2l-ai/d2l-en/issues/1825
I also want to contribute to the JAX version, so I pulled the `jax` branch. However, I found only a few changes on this branch, see:
Did I pull the right branch?
Since some chapters of the book are out of sync with the master branch, which version should be implemented in JAX?
For example, the code in Linear Regression from Scratch:
```python
def synthetic_data(w, b, num_examples):  #@save
    """Generate y = Xw + b + noise."""
    X = np.random.normal(0, 1, (num_examples, len(w)))
    y = np.dot(X, w) + b
    y += np.random.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))

true_w = np.array([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
```
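For comparison, here is a minimal sketch of how this older `synthetic_data` function might look in JAX, assuming the port passes explicit PRNG keys instead of relying on global random state (JAX has none). The extra `key` parameter is an assumption of this sketch, not part of the book's code.

```python
import jax
import jax.numpy as jnp

def synthetic_data(key, w, b, num_examples):
    """Generate y = Xw + b + noise (sketch: `key` is an assumed extra argument)."""
    # JAX has no global RNG state, so split one key into two independent keys
    key_x, key_noise = jax.random.split(key)
    X = jax.random.normal(key_x, (num_examples, w.shape[0]))
    y = jnp.dot(X, w) + b
    # Gaussian noise with standard deviation 0.01, as in the original function
    y = y + 0.01 * jax.random.normal(key_noise, y.shape)
    return X, y.reshape((-1, 1))

true_w = jnp.array([2.0, -3.4])
true_b = 4.2
features, labels = synthetic_data(jax.random.PRNGKey(0), true_w, true_b, 1000)
```

Threading the key through explicitly makes the data reproducible across runs, which is the main behavioral difference from the NumPy-style version.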
is different from the corresponding code in `linear-regression-scratch.md`:
```python
%%tab all
class LinearRegressionScratch(d2l.Module):  #@save
    def __init__(self, num_inputs, lr, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        if tab.selected('mxnet'):
            self.w = d2l.normal(0, sigma, (num_inputs, 1))
            self.b = d2l.zeros(1)
            self.w.attach_grad()
            self.b.attach_grad()
        if tab.selected('pytorch'):
            self.w = d2l.normal(0, sigma, (num_inputs, 1), requires_grad=True)
            self.b = d2l.zeros(1, requires_grad=True)
        if tab.selected('tensorflow'):
            w = tf.random.normal((num_inputs, 1), mean=0, stddev=0.01)
            b = tf.zeros(1)
            self.w = tf.Variable(w, trainable=True)
            self.b = tf.Variable(b, trainable=True)
```
In this case, should I manually modify the version from the master branch, like this:
```diff
         if tab.selected('tensorflow'):
             w = tf.random.normal((num_inputs, 1), mean=0, stddev=0.01)
             b = tf.zeros(1)
             self.w = tf.Variable(w, trainable=True)
             self.b = tf.Variable(b, trainable=True)
+        if tab.selected('jax'):
+            key = random.PRNGKey(42)
+            self.w = random.normal(key, (num_inputs, 1)) * 0.01 + 0
+            self.b = jnp.zeros(1)
```
Or would it be better to make the `d2l` module work with JAX first?
What do you suggest? @astonzhang
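On the question of making the `d2l` module work with JAX: one possible direction, sketched here under assumptions, is to keep the parameters in a pytree (a plain dict) and obtain gradients with `jax.grad`, since JAX has no counterpart to `attach_grad` or `requires_grad`. The names `init_params`, `forward`, `loss`, and `sgd_step` are hypothetical and are not part of the d2l API. Unlike the proposed diff, the initializer takes the key as an argument and scales by `sigma` rather than hardcoding `PRNGKey(42)` and `0.01`.

```python
import jax
import jax.numpy as jnp

def init_params(key, num_inputs, sigma=0.01):
    """Hypothetical JAX counterpart of LinearRegressionScratch.__init__."""
    w = sigma * jax.random.normal(key, (num_inputs, 1))  # w ~ N(0, sigma^2)
    b = jnp.zeros(1)
    return {"w": w, "b": b}

def forward(params, X):
    return jnp.matmul(X, params["w"]) + params["b"]

def loss(params, X, y):
    # Squared loss (halved) averaged over the batch
    y_hat = forward(params, X)
    return jnp.mean((y_hat - y.reshape(y_hat.shape)) ** 2 / 2)

@jax.jit
def sgd_step(params, X, y, lr):
    # jax.grad differentiates w.r.t. the first argument, the params dict
    grads = jax.grad(loss)(params, X, y)
    # Return updated parameters instead of mutating them in place
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Usage: full-batch gradient descent on synthetic data
key_data, key_noise, key_init = jax.random.split(jax.random.PRNGKey(0), 3)
true_w = jnp.array([[2.0], [-3.4]])
true_b = 4.2
X = jax.random.normal(key_data, (1000, 2))
y = X @ true_w + true_b + 0.01 * jax.random.normal(key_noise, (1000, 1))

params = init_params(key_init, num_inputs=2)
for epoch in range(100):
    params = sgd_step(params, X, y, 0.1)
```

Because `sgd_step` returns new parameters instead of mutating `self.w` and `self.b`, it can be wrapped in `jax.jit`; how such functions would be attached to `d2l.Module` is the open design question raised here.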
Hey @AnirudhDagar, could you please merge `master` into the `jax` branch?
I'm working on a JAX version for the v1 release of the book and have a few chapters ready for review. I just don't want to open any PRs until the `jax` branch is up to date with v1.
Or should I give it a go?
Hi @atgctg, thanks for your interest in the JAX port. I've synced the branch. I'm almost done with Chapter 3; I didn't know you were working on it as well. Feel free to open a PR for the chapter anyway. In the future, to avoid duplicated effort with two people working on the same thing, let's use a tracker.
Thanks! I'll open a PR then.
It would be great to standardize the API first, so other chapters can build on that. I would love to see your approach as well.
Would you be so kind as to wait until this evening before sending a PR? I'd like to fix some CI issues first, which might affect JAX development. I'll let you know once they are fixed. Thanks! :)
Of course!
I want to contribute to the effort for the JAX version. I see that there's already a branch called `jax` with ~20 commits, but it is quite far behind the main branch. Can you give me more details on the current JAX effort and how I can contribute?