I am trying to optimize a couple of parameters with different sizes, each with its own optimizer, inside a `jax.lax.while_loop`.
To achieve that, I used `optax.multi_transform` and `optax.set_to_zero` to select which parameter to optimize and with which optimizer, and a `jax.lax.cond` to alternate between the two parameter/optimizer pairs based on the step number.
However, when running the code, the runtime of every iteration appears to match the runtime of the most expensive one.
I wonder why this happens and how to prevent this behaviour.
This is similar to #477 in some ways, but that solution does not fit my case.
Here's a (quite heavy) MWE of the problem, using the resolution of #350 as a base and giving both parameters the same size to help pinpoint the issue:
import os
# FORCE CPU
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import jax
import jax.numpy as jnp
from typing import NamedTuple, Any
import optax
from time import perf_counter as chrono
from functools import partial
class AlternateTxState(NamedTuple):
    """State carried by the transformation built in ``alternate_tx``.

    It bundles one shared step counter with the inner states of both wrapped
    transformations. Each update advances the counter and replaces only the
    inner state of the transformation that handled that step.
    """

    # Number of updates applied so far; init_fn creates it as an int32 scalar.
    step: jnp.ndarray
    # Inner state of the first wrapped transformation (tx1).
    tx1_state: Any
    # Inner state of the second wrapped transformation (tx2).
    tx2_state: Any
def alternate_tx(tx1, tx2, every1, every2):
    """Alternate between two gradient transformations on a fixed schedule.

    Within each cycle of ``every1 + every2`` steps, the first ``every1`` steps
    are handled by ``tx1`` and the remaining ``every2`` by ``tx2``. Both inner
    states live in an ``AlternateTxState``. Only the state of the
    transformation used at a given step is advanced; the other is passed
    through unchanged.

    Args:
        tx1: first gradient transformation (anything with ``init``/``update``).
        tx2: second gradient transformation.
        every1: number of consecutive steps handled by ``tx1`` (>= 0).
        every2: number of consecutive steps handled by ``tx2`` (>= 0).

    Returns:
        An ``optax.GradientTransformation`` combining both.

    Raises:
        ValueError: if a period is negative or both are zero. In those cases
            the schedule ``step % (every1 + every2)`` is undefined.
    """
    # Validate eagerly: a zero period would otherwise only surface as an
    # integer modulo-by-zero inside the traced update function.
    if every1 < 0 or every2 < 0 or every1 + every2 <= 0:
        raise ValueError(
            "every1 and every2 must be non-negative with a positive sum, "
            f"got every1={every1}, every2={every2}"
        )
    period = every1 + every2

    def init_fn(params):
        return AlternateTxState(
            step=jnp.zeros([], dtype=jnp.int32),
            tx1_state=tx1.init(params),
            tx2_state=tx2.init(params),
        )

    def _update_tx1(updates, state, params=None):
        new_updates, new_state = tx1.update(updates, state.tx1_state, params)
        return new_updates, state._replace(step=state.step + 1, tx1_state=new_state)

    def _update_tx2(updates, state, params=None):
        new_updates, new_state = tx2.update(updates, state.tx2_state, params)
        return new_updates, state._replace(step=state.step + 1, tx2_state=new_state)

    def update_fn(updates, state, params=None):
        # NOTE(review): per the jax.lax.cond docs, only the selected branch
        # runs when the predicate is an unbatched scalar. Under vmap with a
        # batched predicate, cond becomes a select and BOTH branches run.
        # Either way, each branch returns the whole AlternateTxState, so the
        # untouched inner state also flows through the conditional. That may
        # cost a copy of it every step (TODO confirm with a profile).
        use_tx2 = state.step % period >= every1
        return jax.lax.cond(use_tx2, _update_tx2, _update_tx1, updates, state, params)

    return optax.GradientTransformation(init_fn, update_fn)
def optimizer_alternate(first_params, second_params, tx1, tx2, evry1, evry2):
    """Alternate ``tx1`` on ``first_params`` with ``tx2`` on ``second_params``.

    Each phase is an ``optax.multi_transform``. The active parameter is routed
    to the real optimizer (label ``"tx"``), and the inactive one gets
    ``optax.set_to_zero`` (label ``"zero"``). The two phases are combined with
    ``alternate_tx``: ``evry1`` steps of the first phase, then ``evry2`` steps
    of the second.

    Args:
        first_params: key of the parameter dict that ``tx1`` optimizes. It is
            used as a key in the label tree, which presumably must mirror the
            params dict (TODO confirm against the caller).
        second_params: key of the parameter dict that ``tx2`` optimizes.
        tx1: transformation applied during the first phase.
        tx2: transformation applied during the second phase.
        evry1: number of consecutive steps in the first phase.
        evry2: number of consecutive steps in the second phase.
    """

    def _masked(inner, first_label, second_label):
        # Route each parameter key to either the real transform or zeroing.
        return optax.multi_transform(
            {"tx": inner, "zero": optax.set_to_zero()},
            {first_params: first_label, second_params: second_label},
        )

    phase_one = _masked(tx1, "tx", "zero")
    phase_two = _masked(tx2, "zero", "tx")
    return alternate_tx(phase_one, phase_two, evry1, evry2)
def test(state, opt):
    """Apply one ``opt`` update with all-ones gradients and return the new state.

    The gradients mirror the structure of the module-level ``param_dict``.
    This function is meant to be wrapped in ``jax.jit`` (``make_n_updates``
    does so). Inside a jitted function every value is a tracer, so calling
    ``jax.block_until_ready`` there waits for nothing. Callers that time this
    function must block on the returned state themselves.

    NOTE(review): ``updates`` is not returned, so the compiler may drop any
    work that only feeds it. The timing then covers the state update only.
    """
    gradients = jax.tree.map(jnp.ones_like, param_dict)
    _, new_state = opt.update(gradients, state)
    return new_state
def make_n_updates(n, state, opt, every1=1, every2=1):
    """Time ``n`` successive jitted ``test`` updates of ``opt`` from ``state``.

    The jitted step is first compiled and run once; that result is discarded,
    so the state does not advance during warm-up. Every timed call then blocks
    until its result is ready. JAX dispatches asynchronously, so without
    blocking, ``perf_counter`` would measure the enqueue, not the computation.

    Args:
        n: number of timed updates.
        state: initial optimizer state. The tx1/tx2 label assumes its step
            counter starts at 0 (TODO confirm for non-fresh states).
        opt: the gradient transformation. It is a static jit argument, so it
            must be hashable.
        every1: steps per cycle handled by tx1. Used only for the printed label.
        every2: steps per cycle handled by tx2. Used only for the printed label.
            The defaults reproduce the original "alternate every step" labels.
    """
    # `opt` is static so its Python callables can be closed over during tracing.
    jtest = jax.jit(test, static_argnames="opt")
    # Compile + warm up; block so this work cannot spill into the first timing.
    jax.block_until_ready(jtest(state, opt))
    period = every1 + every2
    for i in range(n):
        start = chrono()
        state = jax.block_until_ready(jtest(state, opt))
        speed = chrono() - start
        used_opt = "tx2" if i % period >= every1 else "tx1"
        print(f"Time for update {i} with {used_opt=}: {speed}")
# Create two distinct sets of parameters;
# they will be updated by two different gradient transforms.
key = jax.random.PRNGKey(2)
key, subkey1, subkey2 = jax.random.split(key, 3)
n_params = int(2e6)
shape1 = (n_params,)
# Same shape as param1, to keep the "same sized parameters" setup. Change it to
# a different shape to explore the size-dependent behaviour.
shape2 = shape1
param1 = jax.random.uniform(subkey1, shape1)
param2 = jax.random.uniform(subkey2, shape2)
param_dict = {"param1": param1, "param2": param2}
print(
    f"---- BENCHMARK : timing of optax.update() for set_to_zeros and adam. Parameter size is {n_params=}."
)
@partial(jax.jit, static_argnames=["tx"])
def jbenchmark(state, tx):
    """Jitted single ``tx`` update on all-ones gradients shaped like ``param1``.

    Unlike ``test``, the gradient here mirrors only ``param1`` (the
    module-level array), not the full parameter dict. ``tx`` is static, so
    each distinct transformation triggers its own compilation.

    ``jax.block_until_ready`` inside this function would act on tracers and
    wait for nothing. Callers must block on the returned state when timing.
    NOTE(review): ``updates`` is not returned, so work that only feeds it may
    be eliminated by the compiler (e.g. for ``set_to_zero``, whose state is
    empty).
    """
    gradients = jax.tree.map(jnp.ones_like, param1)
    _, new_state = tx.update(gradients, state)
    return new_state
# Time one standalone update of each transform. Every call is wrapped in
# jax.block_until_ready because JAX dispatches asynchronously: without it the
# timer stops when the work is enqueued rather than when it finishes.
tx_zero = optax.set_to_zero()
state_zero = tx_zero.init(param1)
jax.block_until_ready(jbenchmark(state_zero, tx_zero))  # compile + warm-up
start = chrono()
_ = jax.block_until_ready(jbenchmark(state_zero, tx_zero))
speed_zero = chrono() - start
tx_adam = optax.adam(learning_rate=1e-3)
state_adam = tx_adam.init(param1)
jax.block_until_ready(jbenchmark(state_adam, tx_adam))  # compile + warm-up
start = chrono()
_ = jax.block_until_ready(jbenchmark(state_adam, tx_adam))
speed_adam = chrono() - start
print(
    f"Speed for optax.set_to_zeros() is {speed_zero} \nSpeed for optax.adam(lr=1e-3) is {speed_adam}"
)
# Each entry: (banner, tx1, tx2). tx1 drives "param1" and tx2 drives "param2",
# and the wrapper alternates between them every single step (1, 1).
for title, tx1, tx2 in (
    (
        "---- Scenario 1: two fast gradient transform (tx1 and tx2 are optax.set_to_zero)",
        optax.set_to_zero(),
        optax.set_to_zero(),
    ),
    (
        "---- Scenario 2: two slower gradient transform (tx1 and tx2 are optax.adam)",
        optax.adam(learning_rate=1e-3),
        optax.adam(learning_rate=1e-3),
    ),
    (
        "---- Scenario 3: one 'slow' (tx1=adam) and one faster (tx2=set_to_zero)",
        optax.adam(learning_rate=1e-3),
        optax.set_to_zero(),
    ),
    (
        "---- Scenario 4: sanity check, just reverse the order of tx1 and tx2 to see if there is any difference",
        optax.set_to_zero(),
        optax.adam(learning_rate=1e-3),
    ),
):
    print(title)
    opt = optimizer_alternate("param1", "param2", tx1, tx2, 1, 1)
    state = opt.init(param_dict)
    make_n_updates(4, state, opt)
Outputs on my machine (CPU only):
---- BENCHMARK : timing of optax.update() for set_to_zeros and adam. Parameter size is n_params=2000000.
Speed for optax.set_to_zeros() is 4.5574000068882015e-05
Speed for optax.adam(lr=1e-3) is 0.0031374410000353237
---- Scenario 1: two fast gradient transform (tx1 and tx2 are optax.set_to_zero)
Time for update 0 with used_opt='tx1': 0.00011168800028826809
Time for update 1 with used_opt='tx2': 0.00010074600049847504
Time for update 2 with used_opt='tx1': 8.024300041142851e-05
Time for update 3 with used_opt='tx2': 7.61519995648996e-05
---- Scenario 2: two slower gradient transform (tx1 and tx2 are optax.adam)
Time for update 0 with used_opt='tx1': 0.04269762199965044
Time for update 1 with used_opt='tx2': 0.060171784999511146
Time for update 2 with used_opt='tx1': 0.015525636999882408
Time for update 3 with used_opt='tx2': 0.06336476400065294
---- Scenario 3: one 'slow' (tx1=adam) and one faster (tx2=set_to_zero)
Time for update 0 with used_opt='tx1': 0.012170193000201834
Time for update 1 with used_opt='tx2': 0.025405665999642224
Time for update 2 with used_opt='tx1': 0.0057674279996717814
Time for update 3 with used_opt='tx2': 0.004441808000592573
---- Scenario 4: sanity check, just reverse the order of tx1 and tx2 to see if there is any difference
Time for update 0 with used_opt='tx1': 0.012651638000534149
Time for update 1 with used_opt='tx2': 0.031053042000166897
Time for update 2 with used_opt='tx1': 0.004927424999550567
Time for update 3 with used_opt='tx2': 0.00997542599998269
Note also that the measured runtimes are quite different from what the benchmark would lead us to expect.
Hi,
I am trying to optimize a couple of parameters with different sizes, each with its own optimizer, inside a `jax.lax.while_loop`. To achieve that, I used `multi_transform` and `set_to_zero` to select which parameter to optimize and with which optimizer, and a `jax.lax.cond` to alternate between the two parameter/optimizer pairs based on the step number. However, when running the code, the runtime of every iteration appears to match the runtime of the most expensive one. I wonder why this happens and how to prevent this behaviour.
This is similar to #477 in some ways, but that solution does not fit my case.
Here's a (quite heavy) MWE of the problem, using the resolution of #350 as a base and giving both parameters the same size to help pinpoint the issue:
Outputs on my machine (CPU only):
Note also that the measured runtimes are quite different from what the benchmark would lead us to expect.