patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

`jit`ing large models for inference has bad compilation performance #770

Open · colehaus opened this issue 4 months ago

colehaus commented 4 months ago
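```python
from typing import Any, Generic, TypeVar

import equinox as eqx
import jax
import numpy as np
from jax.numpy import bfloat16
from numpy import ndarray

Float = TypeVar("Float")  # dtype type variable; its definition was not shown in the original snippet


class Test(eqx.Module, Generic[Float]):
    test: eqx.nn.Linear

    def __init__(self, *, key: jax.Array, dtype: type[Float], in_features: int, out_features: int):
        self.test = eqx.nn.Linear(
            in_features=in_features, out_features=out_features, use_bias=False, key=key, dtype=dtype
        )

    def __call__(self, x: ndarray[Any, Float]) -> ndarray[Any, Float]:
        return self.test(x)


# Jit the module's __call__ for increasingly large square Linear layers and log compile times.
for d in [1_000, 2_000, 4_000, 8_000, 16_000, 32_000, 64_000]:
    t = Test(key=jax.random.PRNGKey(0), dtype=bfloat16, in_features=d, out_features=d)
    print(d)
    with jax.log_compiles():
        eqx.filter_jit(t.__call__)(np.ones(d))
```

Output: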
Finished tracing + transforming matmul for pjit in 0.0008780956268310547 sec
Finished tracing + transforming __call__ for pjit in 0.0028340816497802734 sec
Compiling __call__ with global shapes and types [ShapedArray(float32[1000])]. Argument mapping: [UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(__call__) in 0.05241560935974121 sec
1000
Finished XLA compilation of jit(__call__) in 1.2187151908874512 sec
Finished tracing + transforming matmul for pjit in 0.0005924701690673828 sec
Finished tracing + transforming __call__ for pjit in 0.002209901809692383 sec
Compiling __call__ with global shapes and types [ShapedArray(float32[2000])]. Argument mapping: [UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(__call__) in 0.01014566421508789 sec
2000
Finished XLA compilation of jit(__call__) in 1.414226770401001 sec
Finished tracing + transforming matmul for pjit in 0.0005934238433837891 sec
Finished tracing + transforming __call__ for pjit in 0.002273082733154297 sec
Compiling __call__ with global shapes and types [ShapedArray(float32[4000])]. Argument mapping: [UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(__call__) in 0.03254580497741699 sec
4000
Finished XLA compilation of jit(__call__) in 1.6822988986968994 sec
Finished tracing + transforming matmul for pjit in 0.0006303787231445312 sec
Finished tracing + transforming __call__ for pjit in 0.002046346664428711 sec
Compiling __call__ with global shapes and types [ShapedArray(float32[8000])]. Argument mapping: [UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(__call__) in 0.11545777320861816 sec
8000
Finished XLA compilation of jit(__call__) in 3.7346856594085693 sec
Finished tracing + transforming matmul for pjit in 0.0006031990051269531 sec
Finished tracing + transforming __call__ for pjit in 0.0021157264709472656 sec
Compiling __call__ with global shapes and types [ShapedArray(float32[16000])]. Argument mapping: [UnspecifiedValue].
16000
Finished jaxpr to MLIR module conversion jit(__call__) in 0.615095853805542 sec
Finished XLA compilation of jit(__call__) in 11.609867572784424 sec
Finished tracing + transforming matmul for pjit in 0.0005524158477783203 sec
Finished tracing + transforming __call__ for pjit in 0.0017201900482177734 sec
Compiling __call__ with global shapes and types [ShapedArray(float32[32000])]. Argument mapping: [UnspecifiedValue].
32000
Finished jaxpr to MLIR module conversion jit(__call__) in 2.4231040477752686 sec
Finished XLA compilation of jit(__call__) in 40.67563080787659 sec
Finished tracing + transforming matmul for pjit in 0.0005431175231933594 sec
Finished tracing + transforming __call__ for pjit in 0.0017380714416503906 sec
Compiling __call__ with global shapes and types [ShapedArray(float32[64000])]. Argument mapping: [UnspecifiedValue].
64000
<crash here>

As you can see from the output, the jaxpr-to-MLIR conversion and XLA compilation times grow with the array dimension (XLA compilation goes from about 1.2 s at d = 1,000 to about 40.7 s at d = 32,000) until compilation finally crashes at d = 64,000. I believe this is because we're effectively closing over larger and larger values (the d × d weight matrix), and JAX is doing work that scales with the size of the closed-over values (https://github.com/google/jax/issues/16278 may be related).
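To make the closure hypothesis concrete, here is a minimal plain-JAX sketch (the names `w`, `closed_over`, and `as_argument` are hypothetical, not from the issue). When a function closes over a concrete array, that array is not an input of the traced jaxpr, so JAX has to treat it as a constant of the program; when the same array is passed as an argument, it is an ordinary input and only its shape and dtype enter the trace.

```python
import jax
import jax.numpy as jnp

d = 1_000
w = jnp.ones((d, d), dtype=jnp.float32)  # stand-in for a large weight matrix
x = jnp.ones(d, dtype=jnp.float32)


def closed_over(x):
    # `w` is captured from the enclosing scope, so it is not an input of the jaxpr.
    return w @ x


def as_argument(w, x):
    # `w` is an explicit input: only its shape and dtype enter the traced program.
    return w @ x


closed_jaxpr = jax.make_jaxpr(closed_over)(x)
arg_jaxpr = jax.make_jaxpr(as_argument)(w, x)

print(len(closed_jaxpr.jaxpr.invars))  # 1: only x is an input
print(len(arg_jaxpr.jaxpr.invars))  # 2: both w and x are inputs

# Both compute the same thing; they differ only in how `w` reaches the program.
print(jnp.allclose(jax.jit(closed_over)(x), jax.jit(as_argument)(w, x)))
```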

Flax avoids this issue because it passes the parameters/weights directly as arguments to the function. That seems like perhaps the best approach ATM. Is there a reasonable way to achieve behavior like that in Equinox?

(Unless I'm missing something, this is a pretty significant limitation for e.g. doing inference with language models where you'd want to JIT the sampling for a fixed model.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.1
python: 3.11.9 (main, Apr  6 2024, 17:59:24) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='nld09d8wce', release='5.19.0-45-generic', version='#46~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Jun 7 15:06:04 UTC 20', machine='x86_64')

$ nvidia-smi
Tue Jun 25 00:52:43 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   25C    P2    16W /  300W |    265MiB / 49140MiB |      2%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

Equinox version 0.11.4

colehaus commented 4 months ago

Ah, I realized this is a workable solution:

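```python
import functools as ft
import jax.numpy as jnp

def call(x: ndarray[Any, Float], dynamic, static) -> ndarray[Any, Float]:
    test = eqx.combine(dynamic, static)  # rebuild the module inside the traced function
    return test.__call__(x)
# The weights (arrays) go to jit as an argument; only the non-array parts are closed over.
dynamic, static = eqx.partition(t, eqx.is_array)
jax.jit(ft.partial(call, static=static))(jnp.ones(d), dynamic)
```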
patrick-kidger commented 4 months ago

I think this is happening because you're grabbing `__call__`, which, as a magic method, isn't subject to the same bound-methods-are-PyTrees treatment as regular methods. This is why `t` is being closed over rather than provided as an input.

Can you try doing just `eqx.filter_jit(t)(np.ones(d))` instead?

colehaus commented 4 months ago

Ahh, that does make a big difference. I had gotten into the habit of calling `__call__` explicitly so that jump-to-definition in my editor would be more useful, and hadn't thought of it as anything more than a trivial syntactic transformation.