Closed chan-mander closed 11 months ago
It might be helpful to have a minimal reproducible example. I see some Brax usage as well, so a short snippet of code that can be copied and pasted into Colab and that reproduces the error would be helpful for debugging.
Also, as a minor note, you can get better syntax highlighting and formatting by putting your code in a fenced code block: a line with three backticks followed by `python`, then the code, then a closing line of three backticks.
E.g.
@partial(jax.jit, static_argnames=("timeSteps",))
def loss_fn(model: eqx.Module, system: brax.System, state: brax.State, timeSteps: int = 2000):
    """Roll out ``timeSteps`` Brax pipeline steps driven by ``model`` and score the trajectory.

    At each step the controller ``model`` is fed ``[state.x.pos[0][0], state.qd[0]]``.
    Its output is recorded and passed as the actuator input to ``pipeline.step``.

    Args:
        model: Equinox controller that maps the length-2 observation to a force.
        system: Brax system passed unchanged to every ``pipeline.step`` call.
        state: Initial pipeline state.
        timeSteps: Number of simulation steps. It is static because it fixes array shapes.

    Returns:
        ``(loss_value, (positions, velocities, forces, final_state))``. ``loss_value``
        is the per-step mean of ``|positions[:, 0]| + |velocities[:, 0]|``, since the
        targets are all zero.
    """

    def step(i: int, carry: tuple):
        # `system` and `model` never change inside the loop, so they are closed over
        # instead of being threaded through the fori_loop carry.
        state_s, positions_s, velocities_s, forces_s = carry
        # Each row is written exactly once, so `.set` states the intent (the buffers
        # start at zero, so this matches the previous `.add`).
        positions_s = positions_s.at[i, :].set(state_s.x.pos[0])
        # NOTE(review): if qd has a single entry (e.g. one slide joint) it is
        # broadcast across all 3 columns; confirm this is intended.
        velocities_s = velocities_s.at[i, :].set(state_s.qd)
        # `.transpose()` on this 1-D array was a no-op, so it is dropped.
        x = jnp.array([state_s.x.pos[0][0], state_s.qd[0]])
        force = model(x)
        forces_s = forces_s.at[i].set(force)
        state_s = pipeline.step(system, state_s, force)
        return (state_s, positions_s, velocities_s, forces_s)

    positions = jnp.zeros((timeSteps, 3), dtype=jnp.float32)
    velocities = jnp.zeros((timeSteps, 3), dtype=jnp.float32)
    forces = jnp.zeros((timeSteps, 1), dtype=jnp.float32)
    carry = (state, positions, velocities, forces)
    state, positions, velocities, forces = jax.lax.fori_loop(0, timeSteps, step, carry)

    # Targets must be 1-D so they line up element-wise with `positions[:, 0]`.
    # A (timeSteps, 1) target would broadcast to a (timeSteps, timeSteps) matrix.
    targetposition = jnp.zeros(timeSteps)
    targetVelocities = jnp.zeros(timeSteps)
    loss_value = jnp.sum(
        1.0 * jnp.abs(positions[:, 0] - targetposition)
        + 1.0 * jnp.abs(velocities[:, 0] - targetVelocities)
    ) / timeSteps
    return loss_value, (positions, velocities, forces, state)
See also https://docs.kidger.site/equinox/api/debug for advice on common sources of NaNs, and for some of the tools JAX provides to debug them.
Hello,
Thank you for the link. I am currently looking through the debug tools and trying to use them. I am still having this issue; here is a minimal reproducible example. Any insights you may have would be helpful:
import jax
import jax.numpy as jnp
from jax.config import config
import brax
from brax.io import mjcf
from brax.generalized import pipeline
import equinox as eqx
import optax
from functools import partial
class SimpleModel(eqx.Module):
    """Small MLP controller (2 -> 4 -> 4 -> 1, ELU on hidden layers) plus a trainable output bias."""

    layers: list  # eqx.nn.Linear layers, applied in order
    extra_bias: jax.Array  # trainable offset added to the final layer's output

    def __init__(self, key):
        k_in, k_hidden, k_out = jax.random.split(key, 3)
        # Each Linear layer owns trainable weight and bias parameters.
        self.layers = [
            eqx.nn.Linear(2, 4, key=k_in),
            eqx.nn.Linear(4, 4, key=k_hidden),
            eqx.nn.Linear(4, 1, key=k_out),
        ]
        # Also trainable; initialised to one.
        self.extra_bias = jax.numpy.ones(1)

    @jax.jit
    def __call__(self, x):
        *hidden_layers, output_layer = self.layers
        h = x
        for lin in hidden_layers:
            h = jax.nn.elu(lin(h))
        return output_layer(h) + self.extra_bias
@jax.jit
def loss(model: SimpleModel, system: brax.System, state: brax.State):
    """Take one Brax step using the model's force and return the squared distance to the goal.

    The distance is measured on ``x.pos[0][0]``. Returns ``(loss_value, next_state)`` so
    the function can be used with ``has_aux=True``.
    """
    target_x = 0.0
    observation = jnp.array([state.x.pos[0][0], state.qd[0]])
    next_state = pipeline.step(system, state, model(observation.transpose()))
    error = next_state.x.pos[0][0] - target_x
    return 1.0 * error ** 2, next_state
def main():
    """Run 30 single-step updates of SimpleModel, printing the loss and layer-0 weight gradients."""
    rng = jax.random.PRNGKey(seed=0)
    rng, model_key = jax.random.split(rng, 2)
    model = SimpleModel(model_key)

    learning_rate = 0.001
    opt = optax.adam(learning_rate)
    opt_state = opt.init(model)

    # A single box on an x-axis slide joint, driven by one motor.
    system: brax.System = mjcf.loads(
        """<mujoco>
<option timestep="0.001"/>
<worldbody>
<body pos="{x} {y} {z}">
<joint name="slider" type="slide" axis="1 0 0" limited="false"/>
<geom size="0.5 0.5 0.5" type="box" mass="1"/>
</body>
<geom name="floor" size="40 40 40" type="plane"/>
</worldbody>
<actuator>
<motor ctrllimited="true" ctrlrange="-10 10" gear="10" joint="slider" name="slide"/>
</actuator>
</mujoco>""".format(x=1.0, y=0.0, z=0.5))
    initial_qd = jnp.array([0.0])
    state = jax.jit(pipeline.init)(system, system.init_q, initial_qd)

    # The wrapper is pure, so it is built once and reused on every iteration.
    value_and_grad = eqx.filter_value_and_grad(loss, has_aux=True)
    for batch in range(1, 31):
        (loss_value, next_state), grads = value_and_grad(model, system, state)
        print("Batch: ", batch, " Loss: ", loss_value)
        print("Gradients of layer 0 weights:\n", grads.layers[0].weight)
        state = next_state
        updates, opt_state = opt.update(grads, opt_state)
        model = optax.apply_updates(model, updates)


if __name__ == "__main__":
    main()
There seemed to be some issue with using `jax.jit` instead of `eqx.filter_jit`. This works for me:
import jax
import jax.numpy as jnp
from jax.config import config
import brax
from brax.io import mjcf
from brax.generalized import pipeline
import equinox as eqx
import optax
from functools import partial
class SimpleModel(eqx.Module):
    """Two-input, one-output MLP (2 -> 4 -> 4 -> 1, ELU hidden activations) with a learnable bias."""

    layers: list  # eqx.nn.Linear layers, applied in order
    extra_bias: jax.Array  # learnable offset added after the last layer

    def __init__(self, key):
        layer_keys = jax.random.split(key, 3)
        shapes = [(2, 4), (4, 4), (4, 1)]
        # Every Linear layer holds trainable weights and biases.
        self.layers = [
            eqx.nn.Linear(n_in, n_out, key=k)
            for (n_in, n_out), k in zip(shapes, layer_keys)
        ]
        # Also trainable; starts at one.
        self.extra_bias = jax.numpy.ones(1)

    def __call__(self, x):
        last = len(self.layers) - 1
        out = x
        for idx, lin in enumerate(self.layers):
            out = lin(out)
            if idx != last:
                out = jax.nn.elu(out)
        return out + self.extra_bias
def loss(model: SimpleModel, system: brax.System, state: brax.State):
    """One-step loss: squared distance of ``x.pos[0][0]`` from the goal after applying the model's force.

    Returns ``(loss_value, next_state)`` for use with ``has_aux=True``.
    """
    goal_x = 0.0
    features = jnp.array([state.x.pos[0][0], state.qd[0]]).transpose()
    next_state = pipeline.step(system, state, model(features))
    return 1.0 * (next_state.x.pos[0][0] - goal_x) ** 2, next_state
def main():
    """Run 5 update steps of SimpleModel with a JIT-compiled value-and-grad, printing progress.

    Each iteration does the following:
    - evaluates the one-step loss and its gradient w.r.t. the model;
    - prints the loss and the layer-0 weight gradients;
    - carries the stepped Brax state forward;
    - applies an Adam update.
    """
    key = jax.random.PRNGKey(seed=0)
    key, subkey = jax.random.split(key, 2)
    model = SimpleModel(subkey)
    lr = 0.001
    opt = optax.adam(lr)
    opt_state = opt.init(model)
    # A single box on an x-axis slide joint, driven by one motor.
    system: brax.System = mjcf.loads(
        """<mujoco>
<option timestep="0.001"/>
<worldbody>
<body pos="{x} {y} {z}">
<joint name="slider" type="slide" axis="1 0 0" limited="false"/>
<geom size="0.5 0.5 0.5" type="box" mass="1"/>
</body>
<geom name="floor" size="40 40 40" type="plane"/>
</worldbody>
<actuator>
<motor ctrllimited="true" ctrlrange="-10 10" gear="10" joint="slider" name="slide"/>
</actuator>
</mujoco>""".format(x=1.0, y=0.0, z=0.5))
    qd = jnp.array([0.0])
    state = jax.jit(pipeline.init)(system, system.init_q, qd)
    # Build the JIT-wrapped value-and-grad function once, outside the loop.
    # The JIT sits outside the grad, and the same wrapper is reused on every
    # iteration instead of being re-created each time.
    loss_and_grad = eqx.filter_jit(eqx.filter_value_and_grad(loss, has_aux=True))
    for i in range(5):
        l, grad = loss_and_grad(model, system, state)
        print("Batch: ", i+1, " Loss: ", l[0])
        print("Gradients of layer 0 weights:\n", grad.layers[0].weight)
        # With has_aux=True, l is (loss_value, next_state).
        state = l[1]
        updates, opt_state = opt.update(grad, opt_state)
        model = optax.apply_updates(model, updates)


if __name__ == "__main__":
    main()
Batch: 1 Loss: 1.0000165
Gradients of layer 0 weights:
[[-8.9743912e-07 -0.0000000e+00]
[-3.2213477e-06 -0.0000000e+00]
[-3.2706109e-07 -0.0000000e+00]
[ 9.9661854e-08 0.0000000e+00]]
Batch: 2 Loss: 1.0000491
Gradients of layer 0 weights:
[[-9.2726850e-07 -7.6194153e-09]
[-3.2008952e-06 -2.6301928e-08]
[-3.6686370e-07 -3.0145388e-09]
[ 1.3357227e-07 1.0975706e-09]]
Batch: 3 Loss: 1.000098
Gradients of layer 0 weights:
[[-9.5715563e-07 -1.5666025e-08]
[-3.1808418e-06 -5.2061694e-08]
[-4.0678665e-07 -6.6579871e-09]
[ 1.6731423e-07 2.7384772e-09]]
Batch: 4 Loss: 1.0001631
Gradients of layer 0 weights:
[[-9.8711382e-07 -2.4135012e-08]
[-3.1612092e-06 -7.7291816e-08]
[-4.4685981e-07 -1.0925758e-08]
[ 2.0088858e-07 4.9117412e-09]]
Batch: 5 Loss: 1.0002439
Gradients of layer 0 weights:
[[-1.0171552e-06 -3.3021848e-08]
[-3.1420084e-06 -1.0200500e-07]
[-4.8710831e-07 -1.5813924e-08]
[ 2.3429754e-07 7.6064479e-09]]
Notably, an important thing that @lockwo has done is to put the JIT outside the grad, not inside. This is much more efficient! Ideally, the `opt.update` and `apply_updates` lines should also be put inside the JIT; see e.g. this RNN example and its `make_step` function.
That said this should just be an efficiency thing, not a correctness thing.
Okay, I'm not too sure what's going on, but here's a smaller MWE that doesn't use Equinox or Optax. I'd suggest raising this with the Brax authors.
# This gets NaN gradients. Removing the `@jax.jit` decorator on `loss`, or removing the aux output
# from `jax.grad`, results in non-NaN gradients.
import jax
import jax.numpy as jnp
from brax.io import mjcf
from brax.generalized import pipeline
@jax.jit
def loss(weight, system, state):
    """Apply the linear force ``weight @ [x.pos[0][0], qd[0]]`` for one Brax step.

    Returns ``(x.pos[0][0] ** 2 of the new state, new state)`` for use with ``has_aux=True``.
    """
    obs = jnp.array([state.x.pos[0][0], state.qd[0]])
    new_state = pipeline.step(system, state, weight @ obs.transpose())
    return new_state.x.pos[0][0] ** 2, new_state
def main():
    """Print the gradient of the one-step loss w.r.t. a random (1, 2) weight matrix."""
    rng = jax.random.PRNGKey(seed=0)
    _, weight_key = jax.random.split(rng, 2)
    weight = jax.random.normal(weight_key, (1, 2))
    # A single box on an x-axis slide joint, driven by one motor.
    system = mjcf.loads(
        """<mujoco>
<option timestep="0.001"/>
<worldbody>
<body pos="{x} {y} {z}">
<joint name="slider" type="slide" axis="1 0 0" limited="false"/>
<geom size="0.5 0.5 0.5" type="box" mass="1"/>
</body>
<geom name="floor" size="40 40 40" type="plane"/>
</worldbody>
<actuator>
<motor ctrllimited="true" ctrlrange="-10 10" gear="10" joint="slider" name="slide"/>
</actuator>
</mujoco>""".format(x=1.0, y=0.0, z=0.5))
    initial_qd = jnp.array([0.0])
    state = jax.jit(pipeline.init)(system, system.init_q, initial_qd)
    grad_fn = jax.jit(jax.grad(loss, has_aux=True))
    grad, _ = grad_fn(weight, system, state)
    print("Gradients of layer 0 weights:\n", grad)


if __name__ == "__main__":
    main()
I was able to solve the issue. It seems that the problem was in how I was initializing the initial state with Brax. Thank you for your help and input.
Hello,
I am fairly new to JAX and Equinox, but I am running into an issue that I don't quite understand. I have the following loss function:
`@partial(jax.jit, static_argnames=('timeSteps')) def loss_fn(model: eqx.Module, system: brax.System, state: brax.State, timeSteps: int = 2000):`
I try to get the sensitivity of the loss with respect to the model parameters by running the following line:
l, grad = eqx.filter_value_and_grad(loss_fn, has_aux=True)(model, system, newState, timeSteps)
When I run this with JAX JIT compilation I get NaN for the gradients, but when I disable JIT compilation I get non-NaN gradient values. Can anyone help shed some light on why this is happening, and on how I can compute gradients with JIT compilation enabled?