jli05 opened this issue 4 years ago
Hi, would you take a look at our code example (https://github.com/google/objax/blob/master/examples/rnn/shakespeare.py) and see if this resolves your concern?
@jli05, I would recommend looking at the dm-haiku recurrent module (note that the syntax is a bit different): https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/recurrent.py
It includes a VanillaRNN module and two important functions, static_unroll and dynamic_unroll, that will help you JIT the whole sequence in training.
You can also see how to unroll the sampling process at https://github.com/deepmind/dm-haiku/blob/master/examples/rnn/train.py#L98
If you want to know more about how to unroll a loop, you can also look at the JAX flow-control ops (e.g., scan and fori_loop) at https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators
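To make the two ops concrete, here is a small self-contained comparison (a running sum, purely illustrative): scan carries state and collects a per-step output, while fori_loop only carries state.

```python
import jax.numpy as jnp
from jax import lax

# lax.scan: the carry is the running total; the second return value
# of `step` is stacked into a per-step output array.
def step(carry, x):
    carry = carry + x
    return carry, carry

total, running = lax.scan(step, 0.0, jnp.arange(1.0, 5.0))
# total is 10.0, running is [1., 3., 6., 10.]

# lax.fori_loop: same reduction, but with no per-step outputs.
total2 = lax.fori_loop(0, 4, lambda i, acc: acc + (i + 1.0), 0.0)
```

For an RNN, the carry is the hidden state and the per-step outputs are the activations at each time step, which is why scan is the usual choice.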
Thanks. Yes, the references helped.
The Shakespeare example trains slowly on a MacBook CPU: it takes more than 5 minutes to train 10 epochs, and I killed it before it ran to the end.
I also found that one can pickle-dump a jitted network but cannot pickle-load it. Maybe jitted code was never intended for pickling.
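One common workaround, sketched below with a made-up two-parameter model: a jitted callable closes over compiled XLA executables and isn't meant to round-trip through pickle, but the parameters themselves are plain arrays, so you can pickle just those and re-jit after loading.

```python
import pickle
import jax
import jax.numpy as jnp

# Hypothetical parameter pytree (a dict of arrays) standing in for a
# real model's weights.
params = {"A": jnp.eye(2), "b": jnp.ones(2)}

# device_get pulls the arrays back to host memory as NumPy arrays,
# which pickle handles without issue.
blob = pickle.dumps(jax.device_get(params))
restored = pickle.loads(blob)

# Re-create the jitted function after loading; it recompiles on the
# first call with the restored parameters.
@jax.jit
def apply_fn(p, x):
    return jnp.tanh(p["A"] @ x + p["b"])

y = apply_fn(restored, jnp.zeros(2))
```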
@aterzis-google Any help would be appreciated!
I am working on this and will have an update by Monday (Sep 28).
If you could share any lessons or findings about the JAX framework itself from your work, it'd be much appreciated.
JAX is apparently faster than TensorFlow, but it's quite young. I have read some of JAX's internal code and the issues on the JAX repo. I'm still undecided whether to rewrite a TensorFlow application in JAX.
@jli05 I can say that many projects in Google Research use JAX and the frameworks built on top of it. If you can say a little more about your project then maybe I can give some more direct guidance.
I also wanted to ask you about your original question about speed of training.
Did you consider using Colab, which has support for GPUs? Also, have you seen the section about saving and loading model weights? https://objax.readthedocs.io/en/latest/advanced/io.html#saving-and-loading-model-weights
Yes, I read that documentation.
The model we're making is relatively small, but it needs to be distributed and run many times for training and inference. What makes JAX and its derived frameworks attractive is that they are compact: small size, relatively quick load time, and no minimum hardware requirement for development. JAX sticks to its job.
The jax package is about 240 KB; jaxlib is about 40 MB. If jaxlib became optional, that would be even better. I didn't give priority to exploring a 200 MB solution based on other frameworks if the job can be done by a 200 KB one.
Currently we're just trying to get things right on CPU. We were concerned about frequent data I/O, so we haven't started benchmarking on more advanced hardware. We're just getting the basics right at this stage, as everything will be parallelised on a bigger scale.
@jli05 also see https://github.com/google/objax/pull/97 for ongoing work to refactor RNN
I'm looking to write a basic RNN that computes f(Ax + b) at each time step.
What would be the best way to go about it? Could you outline some code to give an idea?
Can one apply JIT over the entire (unrolled) network for training/inference?
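One way to sketch this in plain JAX (not Objax or Haiku), under the assumption that the step is literally f(A z + b) with z the concatenation of the previous hidden state and the current input, and f = tanh; the dimensions and initializer below are arbitrary illustrations:

```python
import jax
import jax.numpy as jnp
from jax import lax

def init_params(key, in_dim, hidden_dim):
    # A acts on the concatenated [hidden, input] vector, so each step
    # is exactly f(A z + b) with z = concat(h, x).
    A = jax.random.normal(key, (hidden_dim, hidden_dim + in_dim)) * 0.1
    b = jnp.zeros(hidden_dim)
    return {"A": A, "b": b}

def rnn(params, xs, h0):
    def step(h, x):
        z = jnp.concatenate([h, x])
        h_new = jnp.tanh(params["A"] @ z + params["b"])
        return h_new, h_new          # carry, per-step output
    h_final, hs = lax.scan(step, h0, xs)
    return h_final, hs

# jax.jit compiles the whole scan (the entire unrolled network) into a
# single XLA program, so training or inference over the full sequence
# runs jitted end to end.
rnn_jit = jax.jit(rnn)

params = init_params(jax.random.PRNGKey(0), in_dim=3, hidden_dim=5)
xs = jnp.ones((10, 3))               # [time=10, features=3]
h_final, hs = rnn_jit(params, xs, jnp.zeros(5))
```

Because scan keeps the loop inside the traced computation, the answer to the JIT question is yes: wrapping the whole unrolled network (and, in training, the loss plus jax.grad of it) in a single jit is the standard pattern.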