jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Computing Jacobian w.r.t. parameters used in lax.associative_scan is unexpectedly slow #19498

Open markschoene opened 7 months ago

markschoene commented 7 months ago

Description

Thanks for considering this issue. My research requires me to compute the jacobian of a recurrent neural network with respect to a set of quantities. The minimal example below considers a linear recurrence $x_t = \lambda \odot x_{t-1} + B u_t$ as found in recent deep state-space models such as S5. Here, $x_t, \lambda \in \mathbb{C}^n$, $u_t \in \mathbb{R}^m$, $B \in \mathbb{C}^{n \times m}$.
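
For concreteness, unrolling the recurrence (my notation, with $\lambda^{\odot k}$ denoting the elementwise $k$-th power) gives

$$x_t = \lambda^{\odot t} \odot x_0 + \sum_{k=1}^{t} \lambda^{\odot (t-k)} \odot (B u_k),$$

i.e. every step is an affine update of the previous state, which is what allows the parallel scan formulation below.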

Problem Description

Since recursion relations can be formulated as associative operators, they can be parallelized using lax.associative_scan.

def binary_operator(element_i, element_j):
    # Binary operator for parallel scan of linear recurrence.
    a_i, bu_i = element_i
    a_j, bu_j = element_j
    return a_j * a_i, a_j * bu_i + bu_j

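For reference, this is how the operator is consumed for a single linear recurrence (a minimal sketch with made-up shapes and $x_0 = 0$, using the binary_operator above; the associative_scan helper in the full script below handles the initial state explicitly):

import jax
import jax.numpy as jnp

T, n = 8, 4
lam = jnp.full((n,), 0.9)          # diagonal recurrence weights lambda
bu = jnp.ones((T, n))              # stand-in for the B u_t terms, with x_0 = 0
# broadcast lambda over time and scan both sequences jointly
elements = (jnp.repeat(lam[None, :], T, axis=0), bu)
# the second output of the scan holds all states x_1, ..., x_T
_, xs = jax.lax.associative_scan(binary_operator, elements, axis=0)
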
My issue arises when using AD to compute the jacobians of the final state with respect to the initial state, the dense input weights, the parameters used inside the scan ($\lambda$ and the biases), and the input sequence, as measured by the code below.

I would expect these operations to have roughly comparable cost and, in particular, to parallelize along the sequence length.

When measuring the compute time on an A100 (40GB), I find that the derivative $\frac{\partial x_t}{\partial \lambda}$ in particular is much slower than the other ones. The compute time of all derivatives increases linearly with sequence length (I measured many different configurations, but the example below sweeps $t = 32, \dots, 512$). Yet, the derivative $\frac{\partial x_t}{\partial \lambda}$ has a much steeper slope, as shown in the image below.

The parameter $\lambda$ is fed to the binary_operator as a_i, a_j. If I replace the return statement with return jax.numpy.ones_like(a_j) * jax.numpy.ones_like(a_i), a_j * bu_i + bu_j, such that AD does not need to trace the value of $\lambda$ through the first argument of the binary_operator, the severe differences in compute times vanish.
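
Concretely, the modified operator looks as follows (binary_operator_no_lambda is just a name I am using here; the actual script below keeps this as a commented-out return statement):

def binary_operator_no_lambda(element_i, element_j):
    # the first output no longer depends on the values of a_i, a_j,
    # so AD does not differentiate through the cumulative products of lambda
    a_i, bu_i = element_i
    a_j, bu_j = element_j
    return jnp.ones_like(a_j) * jnp.ones_like(a_i), a_j * bu_i + bu_j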

For completeness, the example below contains forward and backward mode AD results for networks with 1 and 2 layers.

I am now wondering:

  1. Why is this particular jacobian so slow compared to the others?
  2. Why does it appear not to parallelize the computation along the sequence length despite applying associative_scan (judging from the linearly increasing compute times with a large slope)?

Figure: Note that the slope of the linear fits (in brackets) is an order of magnitude larger for the jacobian (scan), i.e. the jacobian w.r.t. the parameter used iteratively in the scan, $\lambda$.

Code to reproduce the measurement

The code below also allows measuring the same program implemented with lax.scan and with a simple Python for loop. These variants are commented out below; just uncomment them to measure them as well. Compile times for the Python for loop will be quite significant for larger sequence lengths.

import jax
import jax.numpy as jnp
from timeit import timeit
from time import time
from functools import partial
import pandas as pd

def scan(weight, input_sequence, init_state, unroll=1):
    # input sequence has shape (T, state_size)
    def scan_fn(h_t, x_t):
        h_t = weight * h_t + x_t
        return h_t, h_t
    carry, ys = jax.lax.scan(scan_fn, init=init_state, xs=input_sequence, unroll=unroll)
    return carry, ys

def for_loop(weight, input_sequence, init_state):
    # input sequence has shape (T, state_size)
    h_t = init_state
    ys = []
    for x_t in input_sequence:
        h_t = weight * h_t + x_t
        ys.append(h_t)
    return h_t, jnp.stack(ys)

def binary_operator(element_i, element_j):
    # Binary operator for parallel scan of linear recurrence.
    a_i, bu_i = element_i
    a_j, bu_j = element_j
    return a_j * a_i, a_j * bu_i + bu_j
    #return jnp.ones_like(a_i) * jnp.ones_like(a_j), a_j * bu_i + bu_j

def associative_scan(weight, input_sequence, init_state):
    # input sequence has shape (T, state_size)
    sequence_length, state_size = input_sequence.shape
    # broadcast the recurrence weights over time, with one extra slot for the initial state
    W_elements = jnp.repeat(weight[None, ...], sequence_length + 1, axis=0)
    elements = (W_elements, jnp.concatenate([init_state[None, ...], input_sequence], axis=0))
    # the second component of the scan output holds the states x_0, x_1, ..., x_T
    _, ys = jax.lax.associative_scan(binary_operator, elements, axis=0)
    return ys[-1], ys[1:]

def single_layer_model(fun, params, params_scan, input_sequence, init_state):
    x = jnp.einsum('ij,lj->li', params['weight1'], input_sequence) + params_scan['bias1']
    return fun(params_scan['lambda1'], x, init_state)

def two_layer_model(fun, params, params_scan, input_sequence, init_state):
    state1, state2 = init_state
    lambda1, lambda2 = params_scan['lambda1'], params_scan['lambda2']
    x = jnp.einsum('ij,lj->li', params['weight1'], input_sequence) + params_scan['bias1']
    state1, x = fun(lambda1, x, state1)
    x = jnp.einsum('ij,lj->li', params['weight2'], x) + params_scan['bias2']
    state2, x = fun(lambda2, x, state2)
    return state2, x

def jacobian(fun, mode='backward'):
    if mode == 'backward':
        return jax.jacrev(fun)
    elif mode == 'forward':
        return jax.jacfwd(fun)
    else:
        raise NotImplementedError

def state_jacobian_fun(fun, params, params_scan, input_sequence, init_state, mode):
    return jacobian(lambda h: fun(params, params_scan, input_sequence, h)[0], mode=mode)(init_state)

def param_jacobian_fun(fun, params, params_scan, input_sequence, init_state, mode):
    return jacobian(lambda p: fun(p, params_scan, input_sequence, init_state)[0], mode=mode)(params)

def param_scan_jacobian_fun(fun, params, params_scan, input_sequence, init_state, mode):
    return jacobian(lambda p: fun(params, p, input_sequence, init_state)[0], mode=mode)(params_scan)

def input_jacobian_fun(fun, params, params_scan, input_sequence, init_state, mode):
    return jacobian(lambda x: fun(params, params_scan, x, init_state)[0], mode=mode)(input_sequence)

def measure_base(fun, params, params_scan, input_sequence, init_state, num_iterations=100):
    # jit the function over a batch jit->vmap->fun
    jit_fn = jax.jit(jax.vmap(fun, (None, None, 0, 0)))

    # measure compilation time
    start = time()
    out = jax.block_until_ready(jit_fn(params, params_scan, input_sequence, init_state))
    t_compile = time() - start

    # measure run time
    t_run = timeit(lambda: jax.block_until_ready(jit_fn(params, params_scan, input_sequence, init_state)), number=num_iterations) / num_iterations * 1000

    return t_run, t_compile

def main(T, state_size, batch_size):
    key = jax.random.PRNGKey(0)
    x_key, init_key, w1_key, w2_key, l1_key, l2_key, l3_key = jax.random.split(key, 7)
    xs = jax.random.normal(x_key, (batch_size, T, state_size))
    params = {
        'weight1': jax.random.normal(w1_key, (state_size, state_size)),
        'weight2': jax.random.normal(w2_key, (state_size, state_size)),
    }
    params_scan = {
        'lambda1': jax.random.uniform(l1_key, (state_size,)),
        'lambda2': jax.random.uniform(l2_key, (state_size,)),
        'bias1': jax.random.normal(w1_key, (state_size,)),
        'bias2': jax.random.normal(w2_key, (state_size,))
    }

    experiments = {
        #'for loop': for_loop,
        #'scan  (1)': partial(scan, unroll=1),
        #'scan  (8)': partial(scan, unroll=8),
        #'scan (64)': partial(scan, unroll=64),
        'associative scan': associative_scan
    }
    results = []
    for mode in ['backward', 'forward']:
        for tag1, base_fun in experiments.items():
            for tag2, forward_fun, init in zip(
                [1, 2],
                [single_layer_model, two_layer_model],
                [jax.random.normal(init_key, (batch_size, state_size)),
                 (jax.random.normal(init_key, (batch_size, state_size)),
                  jax.random.normal(init_key, (batch_size, state_size)))]
            ):
                measure = partial(
                    measure_base,
                    params=params,
                    params_scan=params_scan,
                    input_sequence=xs,
                    init_state=init
                )
                model_fun = partial(forward_fun, base_fun)

                forward, _ = measure(model_fun)
                state_jac, _ = measure(partial(state_jacobian_fun, model_fun, mode=mode))
                param_jac, _ = measure(partial(param_jacobian_fun, model_fun, mode='backward'))
                param_scan_jac, _ = measure(partial(param_scan_jacobian_fun, model_fun, mode=mode))
                inout_jac, _ = measure(partial(input_jacobian_fun, model_fun, mode='backward'))
                results.append([tag1, mode, tag2, T, forward, state_jac, param_jac, param_scan_jac, inout_jac])

    df = pd.DataFrame(data=results, columns=[
        'func', 'mode', 'layers', 'steps', 'forward [ms]',
        'jacobian (state) [ms]', 'jacobian (params) [ms]', 'jacobian (scan) [ms]',
        'input-output [ms]'])
    df = df.sort_values(by=['layers', 'func', 'mode']).reset_index(drop=True)
    return df

if __name__ == '__main__':
    pd.options.display.float_format = '{:,.2f}'.format

    data = []
    for T in range(1, 17):
        T = T * 32
        print(40 * "*")
        print("Evaluating T =", T)
        df = main(T, state_size=64, batch_size=16)
        data.append(df)
    df = pd.concat(data, ignore_index=True)
    df = df.sort_values(by=['layers', 'func', 'mode']).reset_index(drop=True)
    df.to_csv('measurements/jacobian_minimal_example.csv')

What jax/jaxlib version are you using?

jax v0.4.20 jaxlib v0.4.20+cuda12.cudnn89

Which accelerator(s) are you using?

GPU

Additional system info?

Python 3.10.8, Linux

NVIDIA GPU info

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-40GB          On  | 00000000:0B:00.0 Off |                    0 |
| N/A   43C    P0              65W / 400W |  39830MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM4-40GB          On  | 00000000:11:00.0 Off |                    0 |
| N/A   43C    P0              64W / 400W |  19151MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM4-40GB          On  | 00000000:3B:00.0 Off |                    0 |
| N/A   44C    P0              64W / 400W |  19151MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM4-40GB          On  | 00000000:40:00.0 Off |                    0 |
| N/A   42C    P0              66W / 400W |  19151MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA A100-SXM4-40GB          On  | 00000000:8B:00.0 Off |                    0 |
| N/A   41C    P0              60W / 400W |  19151MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA A100-SXM4-40GB          On  | 00000000:90:00.0 Off |                    0 |
| N/A   43C    P0              64W / 400W |  19151MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA A100-SXM4-40GB          On  | 00000000:BB:00.0 Off |                    0 |
| N/A   43C    P0              63W / 400W |  19151MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA A100-SXM4-40GB          On  | 00000000:C1:00.0 Off |                    0 |
| N/A   42C    P0              63W / 400W |   8255MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
markschoene commented 7 months ago

Just to mention, I would already be happy about helpful comments on how I can try to debug this myself. In the process of debugging, I'd like to inspect the XLA HLO output to see whether the scan is compiled into a serial computation or whether XLA recognizes it as a parallel operation. Therefore, I specify the flags

XLA_FLAGS="--xla_dump_to=my/file/path --xla_dump_hlo_as_dot=true --xla_dump_fusion_visualization=true"

Is there a way to make XLA print human-readable function names? Currently, the output reads, for example, module_0012.jit__unnamed_function. From the operations, I can guess which kernel this refers to, but it would make debugging much easier if the operations could be named.
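
Would something like jax.named_scope help here? A minimal sketch of what I have in mind (named_model is just a hypothetical wrapper, and I am assuming the scope name actually shows up in the dumped HLO metadata, and that jitting a named def rather than a lambda/partial gives the module a readable name):

def named_model(lam, x, h0):
    # attach a readable name to the operations traced inside this scope
    with jax.named_scope("linear_recurrence_scan"):
        return associative_scan(lam, x, h0)

jit_named = jax.jit(named_model)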

markschoene commented 7 months ago

Looking into the HLO output, I am quite sure about the correspondence between graphs and functions in the code.

HLO for T=4 recurrence steps

HLO for T=128 recurrence steps

For T=128, the striking difference between $\frac{\partial h_{128}}{\partial W}$ and $\frac{\partial h_{128}}{\partial \lambda}$ with VJP is the much larger number of operations that seem to be required to compute the vjp.