Thanks! I believe we are also seeing runtime differences in some other applications. We will get back to you soon.
The slowdown in the backward pass and in the compilation time between 0.17.3 and 0.19.1 is due to the changes in `IrrepsArray` introduced in 0.19.0.

Before 0.19.0, `IrrepsArray` stored the data both as a contiguous array and as a list of chunks. For instance, `16x0e + 32x1o` has two chunks, so we would have a contiguous array of shape `(112,)` and a list of arrays of shapes `[(16, 1), (32, 3)]`. This apparently helped the compiler in the backward pass.

Since 0.19.0, `IrrepsArray` only keeps the contiguous array and the chunks are recreated on the fly. For some reason jit does not like that. I will try to investigate further.
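As a rough illustration, here is a minimal sketch (my own reconstruction, not the actual e3nn-jax internals; `chunk_shapes` and the rebuild loop are assumptions) of how the contiguous array and the chunks relate for `16x0e + 32x1o`:

```python
import jax.numpy as jnp

# Contiguous storage for "16x0e + 32x1o": 16*1 + 32*3 = 112 entries.
x = jnp.arange(112.0)

# One chunk of shape (multiplicity, 2l + 1) per irrep.
chunk_shapes = [(16, 1), (32, 3)]  # 16x0e -> (16, 1), 32x1o -> (32, 3)

# Pre-0.19.0: both x and the chunk list were stored on the IrrepsArray.
# Since 0.19.0: only x is stored, and the chunks are rebuilt on demand:
chunks, offset = [], 0
for mul, dim in chunk_shapes:
    size = mul * dim
    chunks.append(x[offset:offset + size].reshape(mul, dim))
    offset += size
```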
I found something relevant. In your test you are taking the derivative wrt `x`:

```python
# backwards
@jax.jit
def grad_fn(x):
    def loss_fn(x):
        return jnp.mean(apply(w, x))
    return jax.grad(loss_fn)(x)
```
If you take the derivative wrt `w`, at least on my CPU, the new versions are not slower. Can you try on your GPU?

```python
# backwards
@jax.jit
def grad_fn(w, x):
    def loss_fn(w, x):
        return jnp.mean(apply(w, x))
    return jax.grad(loss_fn)(w, x)
```
@gerkone please try now with versions `0.20.0` and `0.20.1`.
Hey @mariogeiger, thanks for your replies. I ran a couple of things; here's what I found.
**`Linear` + `tensor_product` (backward wrt `x`)**

For completeness here is the updated plot with version `0.20.0`. It definitely looks better than before, but still slower than `0.17.3`.
**`Linear` + `tensor_product` (backward wrt params)**

I also tried taking the derivative wrt the weights, and indeed runtimes are similar across all versions. Jit times are still very high though; here is the plot and table (with derivatives wrt params as well). The results are unexpected to me, and even taking your explanation into account I can't really find a reason.
| Version | L=1 Frwd | L=1 Bkwd | L=1 Jit | L=3 Frwd | L=3 Bkwd | L=3 Jit | L=5 Frwd | L=5 Bkwd | L=5 Jit |
|---|---|---|---|---|---|---|---|---|---|
| 0.17.3 | 8.26 | 10.10 | 211.15 | 20.42 | 26.68 | 277.30 | 45.92 | 62.44 | 421.97 |
| 0.19.1 | 7.83 | 9.48 | 342.22 | 21.22 | 27.28 | 1949.29 | 49.29 | 55.73 | 15627.83 |
| 0.19.2 | 8.86 | 10.13 | 346.49 | 16.19 | 24.20 | 2849.34 | 22.59 | 30.91 | 27257.30 |
| 0.20.0 | 8.47 | 12.93 | 223.31 | 14.90 | 22.56 | 385.46 | 22.88 | 28.62 | 882.78 |
**`tensor_product` only**

Tensor product only, without a linear layer. No particular difference can be noticed here.
Since this is what made me notice the issue originally, I also wanted to see the differences in practice. Here are the bar plot and table again, on full segnn models (derivative wrt `x`):
| Version | L=1 Frwd | L=1 Bkwd | L=1 Jit | L=2 Frwd | L=2 Bkwd | L=2 Jit | L=3 Frwd | L=3 Bkwd | L=3 Jit |
|---|---|---|---|---|---|---|---|---|---|
| 0.17.3 | 207.98 | 499.78 | 4343.33 | 495.22 | 1096.72 | 9955.54 | 937.68 | 2297.58 | 17534.29 |
| 0.19.1 | 183.45 | 787.24 | 10061.65 | 481.54 | 2321.90 | 28145.21 | 924.88 | 7307.63 | 67915.30 |
| 0.19.2 | 436.03 | 1165.99 | 17921.59 | 615.51 | 2520.09 | 35293.16 | 1116.62 | 7734.88 | 100226.55 |
| 0.20.0 | 469.82 | 1011.10 | 10210.34 | 614.72 | 1818.22 | 18612.05 | 1214.80 | 3473.40 | 39235.92 |
Backward (and jit) times get worse after version `0.17.3` (they blow up at higher orders), as expected, and only start scaling better to higher Ls with version `0.20.0`. Additionally, forward times also seem to have gotten worse with the newer versions. This is more or less in line with what I got with the test script (it had better be; in the end segnn is just a stack of parametrized tensor products).
I agree that, looking through the latest releases, the major change is the lazy chunks in `IrrepsArray`, and reverting to storing both the contiguous and the chunked representation does actually help at higher orders (probably because the contiguous array is larger/has more chunks). On the other hand (assuming I didn't make any mistakes) it looks like there is something else going on, maybe unrelated to `IrrepsArray`, but I can't find what it could be.
I also have a couple of ideas for benchmarking `IrrepsArray` creation/rechunking and operations, but I still haven't gotten to it. In case I manage something I'll let you know.
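For example, a hypothetical micro-benchmark along these lines (entirely my own sketch, not from the thread; `sum_of_chunk_norms` is a made-up stand-in) could isolate the cost of the on-the-fly rechunking inside a jitted function:

```python
import time
import jax
import jax.numpy as jnp

chunk_shapes = [(16, 1), (32, 3)]  # 16x0e + 32x1o, as in the example above

@jax.jit
def sum_of_chunk_norms(x):
    # Rebuild chunks on the fly (what IrrepsArray does since 0.19.0),
    # then do a trivial computation on each chunk.
    out, offset = 0.0, 0
    for mul, dim in chunk_shapes:
        size = mul * dim
        chunk = x[offset:offset + size].reshape(mul, dim)
        out = out + jnp.sum(chunk**2)
        offset += size
    return out

x = jnp.arange(112.0)
sum_of_chunk_norms(x)  # first call compiles

t0 = time.perf_counter()
for _ in range(1000):
    sum_of_chunk_norms(x).block_until_ready()
print(f"{time.perf_counter() - t0:.3f} s for 1000 calls")
```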
I hope this is useful somehow, and thanks for the latest release, which already improved runtimes significantly.
Thanks a lot! Can you try `0.20.1` in the benchmark? Several optimizations have been added back in this release. Also, do you think `Linear` + `tensor_product` (backward wrt `x`) is a good proxy for SEGNN backward?

Sure, here's the updated plot/table with version `0.20.1`, and it indeed looks like you are right! Now it's looking much better already. Still, a couple of things I noticed:
- Forward at L=1 is still slower, roughly on par with `0.19.2`, while on larger Ls it's even faster than `0.17.3`.
- I didn't include it at first because the latest version is `0.20.0` on PyPI, so I had to install it with `pip install git+https://github.com/e3nn/e3nn-jax.git@0.20.1`.
| Version | L=1 Frwd | L=1 Bkwd | L=1 Jit | L=3 Frwd | L=3 Bkwd | L=3 Jit | L=5 Frwd | L=5 Bkwd | L=5 Jit |
|---|---|---|---|---|---|---|---|---|---|
| 0.17.3 | 8.76 | 12.54 | 281.73 | 28.12 | 40.30 | 646.79 | 48.02 | 99.25 | 1192.13 |
| 0.19.1 | 8.68 | 14.34 | 616.73 | 23.19 | 211.03 | 4865.74 | 43.72 | 667.66 | 30524.66 |
| 0.19.2 | 24.08 | 28.08 | 2387.04 | 28.18 | 324.11 | 9087.49 | 40.97 | 959.80 | 56982.16 |
| 0.20.0 | 24.51 | 16.37 | 2054.15 | 28.03 | 156.77 | 4349.48 | 34.83 | 459.85 | 9607.55 |
| 0.20.1 | 23.32 | 12.87 | 1694.26 | 29.97 | 32.96 | 2709.91 | 31.10 | 50.55 | 3579.64 |
As for the other question, I would say it is a good approximation: segnn is of course more involved, but runtimes should be closely related, the main difference being the `IrrepsArray` operations here and there in segnn. Most of the compute is spent in the tensor products anyway.

To me it looks like backward runtime is back to normal (or better). I'm leaving the issue open in case you want to address the L=1 forward regression. Thanks again for your work.
@gerkone please reopen this issue if needed.
Hey. The `tensor_product` implementation has changed substantially from `0.17.3` to `0.19.3`, which made it simpler and much faster in inference. This does not seem to be the case for backpropagation (especially at high Ls). I noticed this while training segnn with `L=3` on a problem that used to take hours with `e3nn-jax 0.17.3`, but with `0.19.3` did not finish in days. Model init and backward jitting also always take a lot of time, which is not a big deal in practice but could mean something.

On the other hand, differentiating `tensor_product` only (without the linear layer) is about the same/faster in the newer versions, which is unexpected.

Detailed results in the table. Forward and backward times are 100 (compiled) passes on a single `Linear` + `tensor_product` layer, with the spherical harmonics at the respective order as input and `1x0e + 1x1o` as output. Jit time refers to the backward pass jitting.

**Reproduction script**