google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0
1.56k stars 166 forks source link

Performance issue: multi_transform and set_to_zero don't prevent computation #993

Open YanisJouanaud opened 1 week ago

YanisJouanaud commented 1 week ago

Hi,

I am trying to optimize two parameters of different sizes, each with its own optimizer, inside a `jax.lax.while_loop`.

To do this, I used `multi_transform` and `set_to_zero` to select which parameter to optimize and with which optimizer, and a `jax.lax.cond` to alternate between the two parameter/optimizer pairs based on the step number.

However, when running the code, every iteration appears to take as long as the most expensive one. Why is that, and how can I prevent this behaviour?

This is somewhat similar to #477, but that solution does not fit my case.

Here's a (fairly heavy) MWE of the problem. It is based on the resolution of #350 and uses parameters of the same size to help pinpoint the issue:

import os

# Hide every CUDA device from this process so the benchmark runs on CPU.
# This is done before `import jax` below, so the setting is already in place
# when JAX first looks for devices.
os.environ.update({"CUDA_VISIBLE_DEVICES": ""})
import jax
import jax.numpy as jnp
from typing import NamedTuple, Any
import optax
from time import perf_counter as chrono
from functools import partial

# State of the alternating transformation built by `alternate_tx`:
#   step      -- int32 scalar counting calls to `update` so far
#   tx1_state -- inner state of the first transformation
#   tx2_state -- inner state of the second transformation
AlternateTxState = NamedTuple(
    "AlternateTxState",
    [
        ("step", jnp.ndarray),
        ("tx1_state", Any),
        ("tx2_state", Any),
    ],
)

def alternate_tx(tx1, tx2, every1, every2):
    """Combine ``tx1`` and ``tx2`` into one transformation that takes turns.

    Within every window of ``every1 + every2`` calls to ``update``, the first
    ``every1`` calls go to ``tx1`` and the remaining ``every2`` go to ``tx2``.
    Both inner states live in an ``AlternateTxState``. A call advances only the
    active transformation's state, plus the shared step counter.

    The branch is chosen with ``jax.lax.cond`` on the (traced) step counter.
    Both branches are therefore traced, and each returns the whole
    ``AlternateTxState``, including the inactive transformation's untouched
    state. Assumes ``every1 + every2`` is non-zero, since it is used as a
    modulus.
    """

    def init_fn(params):
        first_inner = tx1.init(params)
        second_inner = tx2.init(params)
        return AlternateTxState(
            step=jnp.zeros([], dtype=jnp.int32),
            tx1_state=first_inner,
            tx2_state=second_inner,
        )

    def _branch(tx, slot):
        # Build a cond branch that runs `tx` on the inner state stored in
        # field `slot` and bumps the shared step counter.
        def run(updates, state, params=None):
            new_updates, new_inner = tx.update(updates, getattr(state, slot), params)
            return new_updates, state._replace(
                **{"step": state.step + 1, slot: new_inner}
            )

        return run

    use_first = _branch(tx1, "tx1_state")
    use_second = _branch(tx2, "tx2_state")

    def update_fn(updates, state, params=None):
        in_second_phase = state.step % (every1 + every2) >= every1
        return jax.lax.cond(
            in_second_phase, use_second, use_first, updates, state, params
        )

    return optax.GradientTransformation(init_fn, update_fn)

def optimizer_alternate(first_params, second_params, tx1, tx2, evry1, evry2):
    """Alternate ``tx1`` on ``first_params`` with ``tx2`` on ``second_params``.

    Each phase is an ``optax.multi_transform`` that routes one parameter key
    to the active optimizer (label ``"tx"``) and the other key to
    ``optax.set_to_zero()`` (label ``"zero"``). The two phases are then
    scheduled by ``alternate_tx`` with ``evry1`` / ``evry2`` steps each.
    """

    def _phase(active_tx, first_label, second_label):
        # Label dict is built first_params-then-second_params, as before.
        return optax.multi_transform(
            {"tx": active_tx, "zero": optax.set_to_zero()},
            {first_params: first_label, second_params: second_label},
        )

    train_first = _phase(tx1, "tx", "zero")
    train_second = _phase(tx2, "zero", "tx")
    return alternate_tx(train_first, train_second, evry1, evry2)

def test(state, opt):
    """Apply one ``opt.update`` step using all-ones gradients shaped like ``param_dict``.

    Args:
      state: optimizer state for ``opt``, e.g. from ``opt.init(param_dict)``.
      opt: the gradient transformation whose ``update`` is called.

    Returns:
      The new optimizer state. The computed updates are discarded.

    Meant to be wrapped in ``jax.jit``. Inside a jitted function the updates
    are abstract tracers, so calling ``jax.block_until_ready`` on them here
    cannot wait for device work. Anyone timing this function must block on the
    returned state instead.
    """
    gradients = jax.tree.map(jnp.ones_like, param_dict)
    _, state = opt.update(gradients, state)
    return state

def make_n_updates(n, state, opt):
    """Compile ``test`` for ``opt`` and print the wall time of ``n`` jitted steps.

    Args:
      n: number of timed update steps.
      state: initial optimizer state, e.g. from ``opt.init(param_dict)``.
      opt: the gradient transformation. It is passed as a static jit argument,
        so it must be hashable.

    The printed "tx1"/"tx2" label simply alternates with the step index. It
    matches ``alternate_tx``'s schedule only for ``every1 == every2 == 1`` and
    a freshly initialized state (step counter at 0).
    """
    jtest = jax.jit(test, static_argnames="opt")

    # Compile and warm up. Block so that neither compilation nor this first run
    # leaks into timed step 0. The result is discarded, so step 0 starts from
    # `state`.
    jax.block_until_ready(jtest(state, opt))

    for i in range(n):
        start = chrono()
        # JAX dispatches asynchronously. Without blocking on the result, the
        # timer would mostly measure how long it takes to enqueue the work.
        state = jax.block_until_ready(jtest(state, opt))
        speed = chrono() - start
        used_opt = "tx2" if i % 2 else "tx1"
        print(f"Time for update {i} with {used_opt=}: {speed}")

# Create two distinct sets of parameters; they will be updated by two
# different gradient transforms.
key = jax.random.PRNGKey(2)
key, subkey1, subkey2 = jax.random.split(key, 3)

n_params = int(2e6)
shape1 = (n_params,)
# Use the same shape as param1 (the previous `shape2 = shape2` raised a
# NameError). It can be set to a different shape; the author reports that this
# leads to other, related odd behaviour.
shape2 = shape1
param1 = jax.random.uniform(subkey1, shape1)
param2 = jax.random.uniform(subkey2, shape2)

param_dict = {"param1": param1, "param2": param2}

print(
    f"---- BENCHMARK : timing of optax.update() for set_to_zeros and adam. Parameter size is {n_params=}."
)

@partial(jax.jit, static_argnames=["tx"])
def jbenchmark(state, tx):
    """Run one jitted ``tx.update`` on all-ones gradients shaped like ``param1``.

    Args:
      state: optimizer state from ``tx.init(param1)``.
      tx: the gradient transformation. It is a static jit argument, so it must
        be hashable, and each distinct transformation is compiled separately.

    Returns:
      The new optimizer state. The updates are not returned, so the compiler is
      free to skip work that only produces them. Because JAX dispatches
      asynchronously, callers should ``jax.block_until_ready`` the result
      before stopping a timer. Blocking on traced values inside this jitted
      function would not wait for anything.
    """
    # `param1` is a single array, so this equals tree-mapping ones_like over it.
    gradients = jnp.ones_like(param1)
    _, state = tx.update(gradients, state)
    return state

# Baseline: time one jitted update of set_to_zero and of adam on param1 alone.
# Each transformation is called once to compile, then timed on a second call.
# JAX dispatches work asynchronously, so every call is wrapped in
# jax.block_until_ready; without it the timer mostly measures dispatch.
tx_zero = optax.set_to_zero()
state_zero = tx_zero.init(param1)

jax.block_until_ready(jbenchmark(state_zero, tx_zero))  # compile
start = chrono()
_ = jax.block_until_ready(jbenchmark(state_zero, tx_zero))
speed_zero = chrono() - start

tx_adam = optax.adam(learning_rate=1e-3)
state_adam = tx_adam.init(param1)

jax.block_until_ready(jbenchmark(state_adam, tx_adam))  # compile
start = chrono()
_ = jax.block_until_ready(jbenchmark(state_adam, tx_adam))
speed_adam = chrono() - start

print(
    f"Speed for optax.set_to_zeros() is {speed_zero} \nSpeed for optax.adam(lr=1e-3) is {speed_adam}"
)

# Four scenarios, each timing 4 steps of an every-other-step alternation
# between param1 (tx1) and param2 (tx2): both set_to_zero, both adam,
# adam then set_to_zero, and that last pair reversed as a sanity check.
_make_adam = partial(optax.adam, learning_rate=1e-3)
_make_zero = optax.set_to_zero

_SCENARIOS = [
    (
        "---- Scenario 1: two fast gradient transform (tx1 and tx2 are optax.set_to_zero)",
        _make_zero,
        _make_zero,
    ),
    (
        "---- Scenario 2: two slower gradient transform (tx1 and tx2 are optax.adam)",
        _make_adam,
        _make_adam,
    ),
    (
        "---- Scenario 3: one 'slow' (tx1=adam) and one faster (tx2=set_to_zero)",
        _make_adam,
        _make_zero,
    ),
    (
        "---- Scenario 4: sanity check, just reverse the order of tx1 and tx2 to see if there is any difference",
        _make_zero,
        _make_adam,
    ),
]

for banner, make_tx1, make_tx2 in _SCENARIOS:
    print(banner)
    tx1 = make_tx1()
    tx2 = make_tx2()
    opt = optimizer_alternate("param1", "param2", tx1, tx2, 1, 1)
    state = opt.init(param_dict)
    make_n_updates(4, state, opt)

Output on my machine (CPU only):

---- BENCHMARK : timing of optax.update() for set_to_zeros and adam. Parameter size is n_params=2000000.
Speed for optax.set_to_zeros() is 4.5574000068882015e-05 
Speed for optax.adam(lr=1e-3) is 0.0031374410000353237
---- Scenario 1: two fast gradient transform (tx1 and tx2 are optax.set_to_zero)
Time for update 0 with used_opt='tx1': 0.00011168800028826809
Time for update 1 with used_opt='tx2': 0.00010074600049847504
Time for update 2 with used_opt='tx1': 8.024300041142851e-05
Time for update 3 with used_opt='tx2': 7.61519995648996e-05
---- Scenario 2: two slower gradient transform (tx1 and tx2 are optax.adam)
Time for update 0 with used_opt='tx1': 0.04269762199965044
Time for update 1 with used_opt='tx2': 0.060171784999511146
Time for update 2 with used_opt='tx1': 0.015525636999882408
Time for update 3 with used_opt='tx2': 0.06336476400065294
---- Scenario 3: one 'slow' (tx1=adam) and one faster (tx2=set_to_zero)
Time for update 0 with used_opt='tx1': 0.012170193000201834
Time for update 1 with used_opt='tx2': 0.025405665999642224
Time for update 2 with used_opt='tx1': 0.0057674279996717814
Time for update 3 with used_opt='tx2': 0.004441808000592573
---- Scenario 4: sanity check, just reverse the order of tx1 and tx2 to see if there is any difference
Time for update 0 with used_opt='tx1': 0.012651638000534149
Time for update 1 with used_opt='tx2': 0.031053042000166897
Time for update 2 with used_opt='tx1': 0.004927424999550567
Time for update 3 with used_opt='tx2': 0.00997542599998269

Note also that the runtimes are quite different from what the benchmark would lead us to expect.