e3nn / e3nn-jax

JAX library for E3 Equivariant Neural Networks
Apache License 2.0

Backward pass runtime degradation (Linear + `tensor_product`) in the latest versions #38

Closed: gerkone closed this issue 11 months ago.

gerkone commented 1 year ago

Hey. The tensor_product implementation changed substantially between 0.17.3 and 0.19.3, becoming simpler and much faster at inference. This does not seem to hold for backpropagation (especially at high L). I noticed this while training segnn with L=3 on a problem that used to take hours with e3nn-jax 0.17.3 but did not finish in days with 0.19.3.

Model init and backward jitting also take a long time in the newer versions. This is not a big deal in practice, but it could point to the same underlying problem.

[Plot: e3nn-jax version vs backprop time]

On the other hand, differentiating tensor_product alone (without the Linear layer) is about as fast or faster in the newer versions, which is unexpected.

[Plot: e3nn-jax version vs backprop time (tensor_product only)]

Detailed results are in the table below (all times in ms). Frwd and Bkwd are the total time for 100 compiled forward or backward passes through a single Linear + tensor_product layer, with 8 copies of the spherical-harmonics irreps up to order L as input and 1x0e + 1x1o as output. Jit is the time to jit-compile (and run once) the backward pass.

[ms]     |            L=1             |            L=3             |            L=5
Version  |     Frwd     Bkwd      Jit |     Frwd     Bkwd      Jit |     Frwd     Bkwd      Jit
0.17.3   |    16.63    10.70   397.70 |    30.32    27.91   751.56 |    57.19    80.14  1232.86
0.19.1   |    15.14    22.61   747.90 |    32.16   500.49  5220.34 |    52.06  1275.34 34265.00
0.19.2   |    16.19    35.12  1123.24 |    26.00   499.58  8296.36 |    43.45  2113.12 60620.00
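As a quick check on the scale of the regression, this plain-Python sketch computes the backward-time slowdown relative to 0.17.3 for the L=3 and L=5 columns, using only the numbers from the table above.

```python
# Backward times in ms (100 compiled passes, Linear + tensor_product),
# copied from the L=3 and L=5 columns of the table
backward_ms = {
    "0.17.3": {"L=3": 27.91, "L=5": 80.14},
    "0.19.1": {"L=3": 500.49, "L=5": 1275.34},
    "0.19.2": {"L=3": 499.58, "L=5": 2113.12},
}

baseline = backward_ms["0.17.3"]

# slowdown factor = time in a given version / time in 0.17.3, per order L
slowdown = {
    version: {order: t / baseline[order] for order, t in times.items()}
    for version, times in backward_ms.items()
}

for version, ratios in slowdown.items():
    print(version, {order: round(r, 1) for order, r in ratios.items()})
```

This gives roughly 18x at L=3 for both 0.19.x releases and about 16x (0.19.1) to 26x (0.19.2) at L=5, while the forward columns stay in the same range.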

Reproduction script

import os
import jax
import jax.numpy as jnp 
import time
import haiku as hk
import warnings
from multiprocessing import Process, Queue

warnings.filterwarnings("ignore")

def tp_test(queue):
    import e3nn_jax as e3nn

    def run(tp_fn, ir, bs=100, reruns=100):    
        x = e3nn.normal(ir, key=jax.random.PRNGKey(0), leading_shape=(bs,))

        # time init plus the first (compiling) forward call
        st = time.perf_counter_ns()
        tp = hk.without_apply_rng(hk.transform(tp_fn))
        w = tp.init(jax.random.PRNGKey(0), x)
        apply = jax.jit(tp.apply)
        jax.block_until_ready(apply(w, x))
        init_end = (time.perf_counter_ns() - st) / 1e6

        st = time.perf_counter_ns()
        for _ in range(reruns):
            jax.block_until_ready(apply(w, x))
        forward_end = (time.perf_counter_ns() - st) / 1e6

        # backward pass: gradient of the mean output wrt the input x (not the params w)
        @jax.jit
        def grad_fn(x):
            def loss_fn(x):
                return jnp.mean(apply(w, x))
            return jax.grad(loss_fn)(x)

        # the first call includes jit compilation of the backward pass
        st = time.perf_counter_ns()
        jax.block_until_ready(grad_fn(x))
        jit_end = (time.perf_counter_ns() - st) / 1e6

        st = time.perf_counter_ns()
        for _ in range(reruns):
            jax.block_until_ready(grad_fn(x))
        backward_end = (time.perf_counter_ns() - st) / 1e6

        return forward_end, backward_end, init_end, jit_end

    print(f"Version {e3nn.__version__}")

    results = {}

    for L in [1, 3, 5]:
        ir_L = e3nn.Irreps.spherical_harmonics(L) * 8
        def only_tp(x):
            return e3nn.tensor_product(x, x, filter_ir_out="1x0e+1x1o").array

        frwd_time, bkwd_time, init_time, jit_time = run(only_tp, ir_L)
        print(
            f"[L={L} TP only]: forward={frwd_time:.2f}ms (init={init_time:.2f}ms) - "
            f"backward={bkwd_time:.2f}ms (jit={jit_time:.2f}ms)"
        )

        def linear_tp(x):
            return e3nn.haiku.Linear("1x0e+1x1o")(e3nn.tensor_product(x, x)).array

        frwd_time, bkwd_time, init_time, jit_time = run(linear_tp, ir_L)
        print(
            f"[L={L} Linear TP]: forward={frwd_time:.2f}ms (init={init_time:.2f}ms) - "
            f"backward={bkwd_time:.2f}ms (jit={jit_time:.2f}ms)"
        )

        results[f"L={L}"] = bkwd_time

    queue.put({e3nn.__version__: results})

if __name__ == "__main__":

    queue = Queue()

    os.system("pip install e3nn-jax==0.17.3 >/dev/null 2>&1")
    p = Process(target=tp_test, args=(queue,))
    p.start()
    p.join()

    print("---")

    os.system("pip install e3nn-jax==0.19.1 >/dev/null 2>&1")
    p = Process(target=tp_test, args=(queue,))
    p.start()
    p.join()

    print("---")

    os.system("pip install e3nn-jax==0.19.2 >/dev/null 2>&1")
    p = Process(target=tp_test, args=(queue,))
    p.start()
    p.join()

    try:
        import matplotlib.pyplot as plt
        import pandas as pd
        results = {}
        while not queue.empty():
            results.update(queue.get())
        results = pd.DataFrame(results)
        results.plot.bar()
        plt.title("Backward time (Linear + tensor_product only)")
        plt.ylabel("time [ms]")
        plt.savefig("e3nn_backward.png")

    except ImportError:
        pass

ameya98 commented 1 year ago

Thanks! I believe we are also seeing some runtime differences in some other applications. We will get back to you soon.

mariogeiger commented 1 year ago

The slowdown in the backward pass and in the compilation time between 0.17.3 and 0.19.1 is due to the changes in IrrepsArray introduced in 0.19.0.

Before 0.19.0, IrrepsArray was storing the data in both a contiguous array and a list of chunks (for instance 16x0e + 32x1o has two chunks so we would have a contiguous array (112,) and a list of arrays [(16, 1), (32, 3)]). This was somehow helping the compiler for the backward pass.

Since 0.19.0, IrrepsArray only has the contiguous array and the chunks are recreated on the fly. For some reason jit does not like that.
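A minimal plain-JAX sketch of the two layouts described above, using the 16x0e + 32x1o example. The helper names are hypothetical and this is not e3nn-jax's actual IrrepsArray code; it only shows that, with a contiguous-only layout, every access to the chunks is a slice plus a reshape that ends up inside the traced (and differentiated) program.

```python
import jax.numpy as jnp

# "16x0e + 32x1o" as (multiplicity, irrep dimension) pairs: 0e has dim 1, 1o has dim 3
IRREPS = [(16, 1), (32, 3)]


def chunks_from_contiguous(array, irreps):
    """Rebuild the per-irrep chunks from a contiguous (..., dim) array by slicing and reshaping."""
    chunks = []
    start = 0
    for mul, ir_dim in irreps:
        stop = start + mul * ir_dim
        chunks.append(array[..., start:stop].reshape(array.shape[:-1] + (mul, ir_dim)))
        start = stop
    return chunks


def contiguous_from_chunks(chunks):
    """Flatten each (..., mul, ir_dim) chunk and concatenate them back into one array."""
    return jnp.concatenate([c.reshape(c.shape[:-2] + (-1,)) for c in chunks], axis=-1)


x = jnp.arange(112.0)  # 16 * 1 + 32 * 3 = 112
chunks = chunks_from_contiguous(x, IRREPS)
print([c.shape for c in chunks])  # [(16, 1), (32, 3)]
```

Before 0.19.0 both `x` and `chunks` were kept around; since 0.19.0 only `x` is stored and something equivalent to `chunks_from_contiguous` runs whenever the chunks are needed.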

I will try to investigate further.

mariogeiger commented 1 year ago

I found something relevant. In your test you are taking the derivative wrt x.

# backwards
@jax.jit
def grad_fn(x):
    def loss_fn(x):
        return jnp.mean(apply(w, x))
    return jax.grad(loss_fn)(x)

If you take the derivative with respect to w instead, the new versions are not slower, at least on my CPU. Can you try on your GPU?

# backwards
@jax.jit
def grad_fn(w, x):
    def loss_fn(w, x):
        return jnp.mean(apply(w, x))
    return jax.grad(loss_fn)(w, x)
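A minimal sketch of the two gradient variants side by side, using a hypothetical pure-JAX stand-in for the Haiku `apply(w, x)` (not the e3nn-jax model) so it runs on its own. The `argnums` argument of `jax.grad` picks which argument is differentiated: `argnums=0` gives the gradient with respect to the parameters w, `argnums=1` with respect to the input x.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the Haiku apply(w, x): a linear map applied to
# an elementwise product of x with itself (NOT an e3nn tensor product)
def apply(w, x):
    return (x * x) @ w["kernel"]


def loss_fn(w, x):
    return jnp.mean(apply(w, x))


# argnums=0: gradient wrt the params w; argnums=1: gradient wrt the input x
grad_wrt_w = jax.jit(jax.grad(loss_fn, argnums=0))
grad_wrt_x = jax.jit(jax.grad(loss_fn, argnums=1))

w = {"kernel": jnp.ones((16, 4))}
x = jax.random.normal(jax.random.PRNGKey(0), (100, 16))

gw = grad_wrt_w(w, x)  # same pytree structure as w
gx = grad_wrt_x(w, x)  # same shape as x
```

In a training loop only the gradient wrt w is needed, which is the variant reported above as not slower in the new versions on CPU.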

mariogeiger commented 1 year ago

@gerkone please try now with version 0.20.0 and 0.20.1

gerkone commented 1 year ago

Hey @mariogeiger, thanks for your replies. I ran a couple of things; here's what I found:

Script posted in first comment

Linear + tensor_product (backward wrt x)

For completeness, here is the updated plot with version 0.20.0. It definitely looks better than before, but it is still slower than 0.17.3.

Linear + tensor_product (backward wrt params)

I also tried taking the derivative wrt the weights, and indeed runtimes are similar across all versions. Jit times are still very high, though; here are the plot and table (with derivatives wrt params as well).

The results are unexpected to me, and even taking your explanation into account I can't really find a reason.

[ms]     |            L=1             |            L=3             |            L=5
Version  |     Frwd     Bkwd      Jit |     Frwd     Bkwd      Jit |     Frwd     Bkwd      Jit
0.17.3   |     8.26    10.10   211.15 |    20.42    26.68   277.30 |    45.92    62.44   421.97
0.19.1   |     7.83     9.48   342.22 |    21.22    27.28  1949.29 |    49.29    55.73 15627.83
0.19.2   |     8.86    10.13   346.49 |    16.19    24.20  2849.34 |    22.59    30.91 27257.30
0.20.0   |     8.47    12.93   223.31 |    14.90    22.56   385.46 |    22.88    28.62   882.78

tensor_product only

Tensor product only, without a linear layer. No particular difference can be noticed here.

SEGNN

Since this is what originally made me notice the issue, I also wanted to see the differences in practice. Here again are the bar plot and table, this time on full segnn models (derivative wrt x).

         |          L=1 SEGNN            |          L=2 SEGNN            |          L=3 SEGNN
Version  |      Frwd      Bkwd       Jit |      Frwd      Bkwd       Jit |      Frwd      Bkwd       Jit
0.17.3   |    207.98    499.78   4343.33 |    495.22   1096.72   9955.54 |    937.68   2297.58  17534.29
0.19.1   |    183.45    787.24  10061.65 |    481.54   2321.90  28145.21 |    924.88   7307.63  67915.30
0.19.2   |    436.03   1165.99  17921.59 |    615.51   2520.09  35293.16 |   1116.62   7734.88 100226.55
0.20.0   |    469.82   1011.10  10210.34 |    614.72   1818.22  18612.05 |   1214.80   3473.40  39235.92

Backward (and jit) times get worse after 0.17.3, blowing up at higher orders as expected, and their scaling to higher L only improves again with 0.20.0. Forward times also seem to have gotten worse in the newer versions.

This is more or less in line with what I got with the test script (it had better be; in the end, segnn is just a stack of parametrized tensor products).


I agree that, looking across the latest releases, the major change is the lazy chunks in IrrepsArray, and going back to storing both the contiguous and the chunked representation does help at higher orders (probably because the contiguous array is larger/has more chunks). On the other hand (assuming I didn't make any mistakes), it looks like something else is going on, maybe unrelated to IrrepsArray, but I can't find what it could be.

I also have a couple of ideas for benchmarking IrrepsArray creation/rechunking and operations, but I haven't gotten to it yet. If I manage something, I'll let you know.
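One possible shape for such a rechunking benchmark: a minimal plain-JAX harness built on a toy per-chunk op. `IRREPS`, `split`, `per_chunk_op`, and `benchmark` are all hypothetical and unrelated to e3nn-jax internals. It compiles the jitted gradient ahead of time with `jax.jit(...).lower(...).compile()`, so compile time and steady-state backward time are measured separately rather than folding the first call into the jit time.

```python
import time

import jax
import jax.numpy as jnp

# Hypothetical irreps: 8 copies each of l = 0..3 (dims 1, 3, 5, 7), total dim 128
IRREPS = [(8, 1), (8, 3), (8, 5), (8, 7)]


def split(array):
    """Slice and reshape a contiguous (..., 128) array into (..., mul, ir_dim) chunks."""
    chunks, start = [], 0
    for mul, ir_dim in IRREPS:
        stop = start + mul * ir_dim
        chunks.append(array[..., start:stop].reshape(array.shape[:-1] + (mul, ir_dim)))
        start = stop
    return chunks


def per_chunk_op(chunks):
    # toy elementwise op followed by a reduction on every chunk
    return sum(jnp.sum(jnp.sin(c) ** 2) for c in chunks)


def lazy_loss(array):
    # chunks recreated on the fly from the contiguous array (post-0.19.0 style)
    return per_chunk_op(split(array))


def chunked_loss(chunks):
    # chunks passed in directly (pre-0.19.0 style)
    return per_chunk_op(chunks)


def benchmark(loss, arg, reruns=100):
    """Return (compile_ms, run_ms) for the jitted gradient of `loss` at `arg`."""
    grad_fn = jax.jit(jax.grad(loss))
    st = time.perf_counter()
    compiled = grad_fn.lower(arg).compile()  # ahead-of-time compilation only
    compile_ms = (time.perf_counter() - st) * 1e3
    jax.block_until_ready(compiled(arg))  # warm-up call
    st = time.perf_counter()
    for _ in range(reruns):
        jax.block_until_ready(compiled(arg))
    run_ms = (time.perf_counter() - st) * 1e3
    return compile_ms, run_ms


x = jax.random.normal(jax.random.PRNGKey(0), (100, 128))
for name, loss, arg in [("lazy", lazy_loss, x), ("chunked", chunked_loss, split(x))]:
    compile_ms, run_ms = benchmark(loss, arg)
    print(f"{name}: compile={compile_ms:.2f}ms backward={run_ms:.2f}ms")
```

With ops this simple, XLA may well generate equivalent code for both variants, so no particular gap should be expected from the toy itself; the harness is meant to be pointed at real IrrepsArray operations.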

I hope this is useful, and thanks for the latest release, which has already improved runtimes significantly.

mariogeiger commented 1 year ago

Thanks a lot!

gerkone commented 1 year ago

Sure, here's the updated plot/table with version 0.20.1, and it indeed looks like you were right! It's already looking much better. Still, a couple of things I noticed:

I didn't include it at first because the latest version on PyPI is 0.20.0, so I had to install it with pip install git+https://github.com/e3nn/e3nn-jax.git@0.20.1.

[ms]     |            L=1             |            L=3             |            L=5
Version  |     Frwd     Bkwd      Jit |     Frwd     Bkwd      Jit |     Frwd     Bkwd      Jit
0.17.3   |     8.76    12.54   281.73 |    28.12    40.30   646.79 |    48.02    99.25  1192.13
0.19.1   |     8.68    14.34   616.73 |    23.19   211.03  4865.74 |    43.72   667.66 30524.66
0.19.2   |    24.08    28.08  2387.04 |    28.18   324.11  9087.49 |    40.97   959.80 56982.16
0.20.0   |    24.51    16.37  2054.15 |    28.03   156.77  4349.48 |    34.83   459.85  9607.55
0.20.1   |    23.32    12.87  1694.26 |    29.97    32.96  2709.91 |    31.10    50.55  3579.64

As for the other question, I would say it is a good approximation: segnn is of course more involved, but runtimes should be closely related, the main difference being the IrrepsArray operations here and there in segnn. Most of the compute is spent in the tensor products anyway.

To me it looks like the backward runtime is back to normal (or better). I'm leaving the issue open in case you want to address the L=1 forward regression (23.32 ms in 0.20.1 vs 8.76 ms in 0.17.3 for 100 passes). Thanks again for your work.

mariogeiger commented 11 months ago

@gerkone please reopen this issue in case