google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

nnx.jit(aux_fn) is slower than directly using nnx.jit(model.__call__) #4218

Open: JunhongXu opened this issue 1 day ago

JunhongXu commented 1 day ago



Problem you have encountered:

nnx.jit(aux_fn) is slower than directly using nnx.jit(model.__call__), where aux_fn is defined as:

def aux_fn(model, x):
    return model(x)

As I understand it, jitting an auxiliary function with nnx.jit is common practice, and it is required if we want to modify the internal state of the model (https://github.com/google/flax/discussions/3998). However, it seems to be slower than wrapping model.__call__ directly with nnx.jit.

See the Colab link below to reproduce the issue.

Steps to reproduce:

Colab link: https://colab.research.google.com/drive/1cGpcaBaJABUxhZuywgLZELZRwFsT5zve?usp=sharing

For completeness, here is the code:

import time
import jax
from flax import nnx

class MLP(nnx.Module):
    def __init__(self, din: int, dout: int, rngs: nnx.Rngs) -> None:
        self.fc1 = nnx.Linear(din, 128, rngs=rngs)
        self.fc2 = nnx.Linear(128, 128, rngs=rngs)
        self.fc3 = nnx.Linear(128, 128, rngs=rngs)
        self.out = nnx.Linear(128, dout, rngs=rngs)

    def __call__(self, x):
        x = self.fc1(x)
        x = nnx.relu(x)
        x = self.fc2(x)
        x = nnx.relu(x)
        x = self.fc3(x)
        x = nnx.relu(x)
        x = self.out(x)
        return x

def nn_forward(model, x):
    # Auxiliary function: the Module is an explicit argument of the jitted function.
    return model(x)

def benchmark_jax():
    rngs = nnx.Rngs(0)
    din, dout = 29, 7  # Example dimensions
    mlp = MLP(din, dout, rngs)
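    # Jit the bound method: mlp is captured as a closure, so nnx.jit does not traverse it on each call.
    nn_forward_call_no_aux = nnx.jit(mlp.__call__)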

    # Prepare data
    x = jax.random.normal(rngs(), shape=(1, din))
    num_iterations = 1000
    warmup_iters = 100

    for _ in range(warmup_iters):
        _ = nn_forward_call_no_aux(x)

    start_time = time.time()
    for _ in range(num_iterations):
        _ = nn_forward_call_no_aux(x)
    end_time = time.time()

    print(f"JAX forward pass time for {num_iterations} iterations: {end_time - start_time:.5f} seconds")
    print(f"JAX forward pass average time: {(end_time - start_time) / num_iterations:.5f} seconds")

    print("-------------------")
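    # Jit the auxiliary function: mlp is an explicit argument, so nnx.jit traverses its object graph on each call.
    nn_forward_jit = nnx.jit(nn_forward)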
    for _ in range(warmup_iters):
        _ = nn_forward_jit(mlp, x)

    start_time = time.time()
    for _ in range(num_iterations):
        _ = nn_forward_jit(mlp, x)
    end_time = time.time()
    print(f"JAX forward pass time while using auxiliary functions for {num_iterations} iterations: {end_time - start_time:.5f} seconds")
    print(f"JAX forward pass average while using auxiliary functions time: {(end_time - start_time) / num_iterations:.5f} seconds")

benchmark_jax()

The output on an RTX 4090 is:

JAX forward pass time for 1000 iterations: 0.10531 seconds
JAX forward pass average time: 0.00011 seconds
-------------------
JAX forward pass time while using auxiliary functions for 1000 iterations: 0.59596 seconds
JAX forward pass average while using auxiliary functions time: 0.00060 seconds
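Because JAX dispatches work asynchronously, loops like the ones above mostly measure Python-side dispatch overhead, which is the quantity at issue here. If device time should be included as well, a hypothetical helper such as time_per_call (not part of JAX or Flax) can block on the result before reading the clock. A minimal sketch:

```python
import time

import jax
import jax.numpy as jnp


def time_per_call(fn, *args, warmup: int = 100, iters: int = 1000) -> float:
    """Return the mean wall-clock seconds per call of fn(*args)."""
    out = None
    for _ in range(warmup):
        out = fn(*args)
    # Wait for all queued warmup work before starting the clock.
    jax.block_until_ready(out)
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(*args)
    # Wait for the last dispatched computation before stopping the clock.
    jax.block_until_ready(out)
    return (time.perf_counter() - start) / iters


f = jax.jit(lambda x: jnp.tanh(x) * 2.0)
print(f"{time_per_call(f, jnp.ones((1, 29))):.6f} s/call")
```

The same helper could time either variant from the issue, e.g. time_per_call(nn_forward_jit, mlp, x).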

cgarciae commented 22 hours ago

Using mlp.__call__ is not recommended because self is then passed as a capture. Try MLP.__call__ instead, passing mlp as the first argument.

cgarciae commented 21 hours ago

To clarify what is happening: with mlp.__call__, nnx.jit does not traverse self, so it is faster, a lot faster in this case. We are going to develop a Rust extension (see #4196), so nnx.jit should become fast in the future. For now, consider using this pattern to remove the Python overhead.

cgarciae commented 21 hours ago

I've created this mini guide to clarify the situation around performance: #4224.