google / objax


Could you outline how to write the simplest RNN Module? #31


jli05 commented 4 years ago

I'm looking to write a basic RNN that computes f(Ax + b) at each time step.

What would be the best way to go about it? Could you outline some code to give an idea?

Can one apply JIT over the entire (unrolled) network for training/inference?
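For concreteness, here is one minimal sketch of such a module in Objax (the `SimpleRNN` name, the sizes, and the tanh nonlinearity are illustrative assumptions, not an official recipe). The Python loop is unrolled at trace time, so `objax.Jit` compiles the whole network:

```python
import jax.numpy as jn
import objax

class SimpleRNN(objax.Module):
    """Elman-style cell: at each step, h = tanh(A @ [x_t, h] + b)."""

    def __init__(self, nin: int, nhidden: int):
        self.nhidden = nhidden
        # objax.nn.Linear holds the weight A and bias b as trainable variables.
        self.cell = objax.nn.Linear(nin + nhidden, nhidden)

    def __call__(self, x):  # x: (time, batch, nin)
        h = jn.zeros((x.shape[1], self.nhidden))
        outputs = []
        for t in range(x.shape[0]):  # Python loop = static unroll at trace time
            h = objax.functional.tanh(self.cell(jn.concatenate([x[t], h], axis=1)))
            outputs.append(h)
        return jn.stack(outputs)  # (time, batch, nhidden)

model = SimpleRNN(nin=8, nhidden=16)
# JIT over the entire unrolled network: the loop is traced once and compiled.
forward = objax.Jit(lambda x: model(x), model.vars())
y = forward(objax.random.normal((20, 4, 8)))
```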

kihyuks commented 4 years ago

Hi, could you take a look at our code example (https://github.com/google/objax/blob/master/examples/rnn/shakespeare.py) and see whether it resolves your question?

NTT123 commented 4 years ago

@jli05, I would recommend looking at the dm-haiku recurrent module (note that the syntax is a bit different):

https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/recurrent.py

It includes a VanillaRNN module and two important functions, static_unroll and dynamic_unroll, that will help you JIT the whole sequence during training.
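For illustration, a minimal sketch of that pattern (the hidden size and input shapes are made up for the example; hk.VanillaRNN and hk.dynamic_unroll are the actual Haiku APIs):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def unroll_fn(xs):  # xs: (time, batch, nin), time-major
    core = hk.VanillaRNN(hidden_size=16)
    initial_state = core.initial_state(batch_size=xs.shape[1])
    # dynamic_unroll drives the core over the whole sequence with lax.scan,
    # so a single jit covers every time step.
    outputs, final_state = hk.dynamic_unroll(core, xs, initial_state)
    return outputs

model = hk.without_apply_rng(hk.transform(unroll_fn))
xs = jnp.ones((20, 4, 8))
params = model.init(jax.random.PRNGKey(0), xs)
ys = jax.jit(model.apply)(params, xs)  # (20, 4, 16)
```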

You can also see how to unroll the sampling process at https://github.com/deepmind/dm-haiku/blob/master/examples/rnn/train.py#L98

If you want to know more about how to unroll a loop, you can also look at the JAX control-flow ops (e.g., scan and fori_loop) at https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators
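As a sketch of the scan route, the f(Ax + b) recurrence from the original question could be written directly in JAX like this (parameter shapes and initialization are illustrative):

```python
import jax
import jax.numpy as jnp

def rnn(params, xs, h0):
    """xs: (time, nin); h0: (nhidden,); params = (A, b)."""
    A, b = params

    def step(h, x):
        h_new = jnp.tanh(A @ jnp.concatenate([x, h]) + b)
        return h_new, h_new  # (next carry, per-step output)

    # scan compiles the loop as a single XLA while-loop instead of unrolling it.
    h_final, hs = jax.lax.scan(step, h0, xs)
    return h_final, hs

nin, nhidden = 8, 16
A = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (nhidden, nin + nhidden))
b = jnp.zeros(nhidden)
h_final, hs = jax.jit(rnn)((A, b), jnp.ones((20, nin)), jnp.zeros(nhidden))
```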

jli05 commented 4 years ago

Thanks. Yes, the references helped.

The Shakespeare code trains slowly on a MacBook CPU: it takes >5 minutes to train 10 epochs, and I killed it before it ran to the end.

I also found that one can pickle-dump a jitted network but cannot pickle-load it back. Maybe jitted code was never intended for pickling.

kihyuks commented 4 years ago

@aterzis-google Any help would be appreciated!

aterzis-google commented 4 years ago

I am working on this and will have an update by Monday (Sep 28).

jli05 commented 4 years ago

If you could share any lessons or findings about the JAX framework itself from your work, it'd be much appreciated.

JAX is apparently faster than TensorFlow, but it's quite young. I've read some of JAX's internal code and the issues on the JAX repo, and I'm still undecided whether to rewrite a TensorFlow application in JAX.

aterzis-google commented 4 years ago

@jli05 I can say that many projects in Google Research use JAX and the frameworks built on top of it. If you can say a little more about your project, maybe I can give some more direct guidance.

I also wanted to follow up on your original question about training speed.

Did you consider using Colab, which has GPU support? Also, have you seen the section on saving and loading model weights at https://objax.readthedocs.io/en/latest/advanced/io.html#saving-and-loading-model-weights?
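For reference, that page persists the model's variables rather than pickling the jitted callable; a minimal sketch, assuming any objax.Module stands in for the model:

```python
import objax

model = objax.nn.Linear(8, 4)  # stands in for any objax.Module

# Persist only the weights (the VarCollection), not the compiled function.
objax.io.save_var_collection('model.npz', model.vars())

# Later: rebuild the module the same way, then restore its weights in place.
objax.io.load_var_collection('model.npz', model.vars())
```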

jli05 commented 4 years ago

Yes, I read that documentation.

The model we're making is relatively small, but it needs to be distributed and run many times for training and inference. What makes JAX and the frameworks derived from it attractive is that they are compact and small in size, load relatively quickly, and have no minimum hardware requirement for development. They stick to their job.

The jax package is about 240 kB and jaxlib is about 40 MB. If jaxlib became optional, that would be even better. I didn't give priority to exploring a 200 MB solution based on other frameworks when the job can be done with a 200 kB one.

Currently we're just trying to get things right on CPU. We were concerned about frequent data I/O, so we haven't started benchmarking on more advanced hardware. We're just getting the basics right at this stage, since everything will be parallelised at a larger scale.

aterzis-google commented 3 years ago

@jli05 Also see https://github.com/google/objax/pull/97 for ongoing work to refactor the RNN module.