markschoene opened 7 months ago
Just to mention, I would already be happy about helpful comments on how I can try to debug this myself. In the process of debugging, I'd like to inspect the XLA HLO output to see whether it is compiled into a serial computation or whether XLA recognizes the scan as a parallel operation. Therefore, I specify the flags
XLA_FLAGS="--xla_dump_to=my/file/path --xla_dump_hlo_as_dot=true --xla_dump_fusion_visualization=true"
Is there a way to make XLA print human-readable function names? Currently, the output reads, for example, `module_0012.jit__unnamed_function`. From the operations, I can guess which kernel this refers to, but it would make debugging much easier if the operations could be named.
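For reference, this is roughly how the dump is triggered from Python (a sketch with a placeholder path, not my exact script); the comments mark what is an unverified assumption on my side about where the `jit__unnamed_function` prefix comes from:

```python
import os

# The flags must be set before jax is imported so XLA picks them up
# (placeholder dump path; the real run sets XLA_FLAGS in the shell).
os.environ["XLA_FLAGS"] = (
    "--xla_dump_to=/tmp/xla_dump "
    "--xla_dump_hlo_as_dot=true "
    "--xla_dump_fusion_visualization=true"
)

import jax
import jax.numpy as jnp

# Assumption (not verified): the "jit__unnamed_function" prefix comes from
# jitting a callable without a usable __name__ (e.g. a lambda or a
# functools.partial). Jitting a plain named def should give a "jit_forward"
# prefix instead, and jax.named_scope tags the enclosed ops with a readable scope.
@jax.jit
def forward(x):
    with jax.named_scope("cumulative_sum"):
        return jnp.cumsum(x)

forward(jnp.arange(8.0))  # triggers compilation and therefore the HLO dump
```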
Looking into the HLO output, I am quite sure about the correspondence between graphs and functions in the code.
HLO for T=4 recurrence steps
HLO for T=128 recurrence steps
For T=128, the striking difference between $\frac{\partial h_{128}}{\partial W}$ and $\frac{\partial h_{128}}{\partial \lambda}$ with VJP is the much larger number of operations that seem to be required to compute the VJP.
Description
Thanks for considering this issue. My research requires me to compute the Jacobian of a recurrent neural network with respect to a set of quantities. The minimal example below considers a linear recurrence $x_t = \lambda \odot x_{t-1} + B u_t$ as found in recent deep state-space models such as S5. Here, $x_t, \lambda \in \mathbb{C}^n$, $u_t \in \mathbb{R}^m$, $B \in \mathbb{C}^{n \times m}$.
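As a reference point, the recurrence written as a plain loop (a sketch with my own variable names, not the benchmark code):

```python
import jax.numpy as jnp

def recurrence_loop(lmbda, B, u, x0):
    # x_t = lmbda * x_{t-1} + B @ u_t, unrolled step by step
    xs = []
    x = x0
    for u_t in u:                  # u has shape (T, m)
        x = lmbda * x + B @ u_t    # lmbda and x have shape (n,), B has shape (n, m)
        xs.append(x)
    return jnp.stack(xs)           # shape (T, n): all states x_1, ..., x_T
```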
Problem Description
Since recursion relations can be formulated as associative operators, they can be parallelized using `lax.associative_scan`. My issue arises when using AD to compute derivatives of the state $x_t$ with respect to the model parameters (such as $\lambda$ and $B$). I would expect these operations to be parallelized via `associative_scan` as well, since the forward pass is also parallelized.

When measuring the compute time on an A100 (40GB), I find that the derivative $\frac{\partial x_t}{\partial \lambda}$ in particular is much slower than the others. The compute time of all derivatives increases linearly with sequence length (I measured many different configurations, but the example below does it for $t = 32, \dots, 512$). Yet, the derivative $\frac{\partial x_t}{\partial \lambda}$ has a much steeper slope, as shown in the image below.
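For concreteness, here is a sketch of the parallel formulation and of the slow quantity. This is my own reconstruction of an S5-style binary operator, not the exact benchmark code, and it is written for real-valued arrays to keep the `jacrev` call simple (the original uses complex $\lambda$ and $B$):

```python
import jax
import jax.numpy as jnp
from jax import lax

def binary_operator(element_i, element_j):
    # Compose two affine maps x -> a * x + bu; this composition is associative.
    a_i, bu_i = element_i
    a_j, bu_j = element_j
    return a_j * a_i, a_j * bu_i + bu_j

def scan_states(lmbda, B, u):
    bu = u @ B.T                           # (T, n): inputs projected into the state space
    a = jnp.broadcast_to(lmbda, bu.shape)  # (T, n): the decay parameter repeated over time
    _, xs = lax.associative_scan(binary_operator, (a, bu))
    return xs                              # xs[t] == x_t

# The quantity that scales badly: the Jacobian of the last state w.r.t. lambda.
jacobian_wrt_lambda = jax.jit(
    jax.jacrev(lambda lmbda, B, u: scan_states(lmbda, B, u)[-1], argnums=0)
)

# Example usage with illustrative shapes (not the benchmark's sizes):
# T, n, m = 128, 64, 32
# J = jacobian_wrt_lambda(jnp.ones(n), jnp.ones((n, m)), jnp.ones((T, m)))  # (n, n)
```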
The parameter $\lambda$ is fed to the `binary_operator` as `a_i, a_j`. If I replace the return statement with `return jax.numpy.ones_like(a_j) * jax.numpy.ones_like(a_i), a_j * bu_i + bu_j`, such that AD doesn't need to trace the value of $\lambda$ through the first argument of the `binary_operator`, the severe differences in compute times vanish (see the sketch below).

For completeness, the example below contains forward and backward mode AD results for networks with 1 and 2 layers.
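Written out, the modified operator looks like this (my reconstruction of the probe described above; note it also changes the values in the first output channel, so it is only meant to cut the AD dependence on $\lambda$, not to produce the same result):

```python
import jax.numpy as jnp

def binary_operator_detached(element_i, element_j):
    a_i, bu_i = element_i
    a_j, bu_j = element_j
    # The first output no longer carries lambda's value, so reverse-mode AD has
    # nothing to propagate for lambda through the a-channel of the scan.
    return jnp.ones_like(a_j) * jnp.ones_like(a_i), a_j * bu_i + bu_j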
I am now wondering why the derivative with respect to $\lambda$ is so much more expensive than the others, and whether this behavior can be avoided.
Figure: Note that the slope of the linear fits (shown in brackets) is an order of magnitude larger for the Jacobian labeled "(scan)", i.e. w.r.t. the parameter used iteratively in the scan, $\lambda$.
Code to reproduce the measurement
The code below also allows measuring the same program implemented with `lax.scan` and a simple Python for loop. These variants are commented out; just uncomment them to also measure these quantities. Compile times for the Python for loop will be quite significant for larger sequence lengths.
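For orientation, the sequential `lax.scan` variant would look roughly like this (a sketch of the commented-out alternative, not the exact benchmark code):

```python
import jax.numpy as jnp
from jax import lax

def scan_states_sequential(lmbda, B, u):
    # Same recurrence as above, but expressed with the inherently sequential lax.scan.
    bu = u @ B.T                    # (T, n)

    def step(x, bu_t):
        x = lmbda * x + bu_t        # one recurrence step
        return x, x                 # carry the state and also emit it

    _, xs = lax.scan(step, jnp.zeros_like(bu[0]), bu)
    return xs                       # (T, n)
```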
What jax/jaxlib version are you using?

jax v0.4.20, jaxlib v0.4.20+cuda12.cudnn89
Which accelerator(s) are you using?
GPU
Additional system info?
Python 3.10.8, Linux
NVIDIA GPU info