google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Constraint violation causes L-BFGS-B to fail #590

Open gulls-on-parade opened 7 months ago

gulls-on-parade commented 7 months ago

I believe the line search used internally by jaxopt.LBFGSB is not respecting the bounds passed here, causing the objective function to generate NaNs and the overall optimization to fail. I am unsure if this is a bug or if I am doing something wrong. Any guidance is much appreciated.

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count={}'.format(os.cpu_count())
import jax as jx
from jax import jit
import jax.numpy as jnp
import jaxopt

#%% Helper functions

@jit
def normal_vol(strike, atmf, t, alpha, beta, rho, nu):
    """SABR normal (Bachelier) vol approximation.

    The jnp.select calls below switch to the ATM (atmf ~= strike) and
    beta -> 1 limits to avoid 0/0 in the generic expressions.
    """
    eps = 1e-07  # Numerical tolerance for the limit switches
    f_av = jnp.sqrt(atmf * strike)  # Geometric average of forward and strike

    fmkr = jnp.select([(jnp.abs(atmf - strike) > eps) & (jnp.abs(1 - beta) > eps), 
                       (jnp.abs(atmf - strike) > eps) & (jnp.abs(1 - beta) <= eps), 
                       jnp.abs(atmf - strike) <= eps],
                      [(1 - beta) * (atmf - strike) / (atmf**(1 - beta) - strike**(1 - beta)),
                       (atmf - strike) / jnp.log(atmf / strike),
                       strike**beta],
                      jnp.nan)

    zeta = nu * (atmf - strike) / (alpha * f_av**beta)

    zxz = jnp.select([jnp.abs(zeta) > eps, 
                      jnp.abs(zeta) <= eps],
                     [zeta / jnp.log(jnp.abs(((1 - 2 * rho * zeta + zeta**2)**.5 + zeta - rho) / (1 - rho))),
                      1.],
                     jnp.nan)

    # First-order (in t) correction terms of the expansion
    a = - beta * (2 - beta) * alpha**2 / (24 * f_av**(2 - 2 * beta))
    b = rho * alpha * nu * beta / (4 * f_av**(1 - beta))
    c = (2 - 3 * rho**2) * nu**2 / 24

    vol = alpha * fmkr * zxz * (1 + (a + b + c) * t)

    return vol

@jit
def _obj(params, args):
    """Objective function to minimize the squared error between implied and model vols."""
    expiry, tail, strikes, vols, atmf, beta = args
    alpha, rho, nu = params
    vol_fitted = jx.vmap(normal_vol, (0, None, None, None, None, None, None))(strikes, atmf, expiry, alpha, beta, rho, nu)
    error = (vol_fitted - vols) * 1e4  # Vol error in basis points
    return jnp.sum(error**2)

#%% Example problem

data = [(0.09041095890410959,
  0.2465753424657534,
  jnp.array([0.0824076, 0.0849076, 0.0874076, 0.0899076, 0.0924076, 0.0949076,
         0.0974076, 0.0999076, 0.1024076, 0.1049076, 0.1074076, 0.1099076,
         0.1124076, 0.1149076, 0.1174076, 0.1199076, 0.1224076, 0.1249076,
         0.1274076, 0.1299076, 0.1324076, 0.1349076, 0.1374076, 0.1399076,
         0.1424076]),
  jnp.array([0.02100495, 0.02000676, 0.01897691, 0.01791351, 0.016814,
         0.01567488, 0.01449142, 0.0132571 , 0.01196264, 0.0105943 ,
         0.00913049, 0.00753422, 0.00573621, 0.00368666, 0.00298916,
         0.00351651, 0.00417858, 0.00485768, 0.00553383, 0.00620241,
         0.00686251, 0.00751431, 0.00815832, 0.00879512, 0.00942527]),
  0.11240760359238675,
  0.25),
 (0.09041095890410959,
  1.0027397260273974,
  jnp.array([0.07611851, 0.07861851, 0.08111851, 0.08361851, 0.08611851,
         0.08861851, 0.09111851, 0.09361851, 0.09611851, 0.09861851,
         0.10111851, 0.10361851, 0.10611851, 0.10861851, 0.11111851,
         0.11361851, 0.11611851, 0.11861851, 0.12111851, 0.12361851,
         0.12611851, 0.12861851, 0.13111851, 0.13361851, 0.13611851]),
  jnp.array([0.02571163, 0.02466922, 0.02359377, 0.02248411, 0.02133859,
         0.02015503, 0.01893064, 0.01766194, 0.01634481, 0.01497479,
         0.01354828, 0.01206712, 0.01055505, 0.00911653, 0.00807032,
         0.00778549, 0.00810574, 0.00870589, 0.0094221 , 0.01018791,
         0.01097495, 0.01177004, 0.01256661, 0.01336128, 0.01415225]),
  0.10611850901102435,
  0.25),
 (0.09041095890410959,
  2.0027397260273974,
  jnp.array([0.06970405, 0.07220405, 0.07470405, 0.07720405, 0.07970405,
         0.08220405, 0.08470405, 0.08720405, 0.08970405, 0.09220405,
         0.09470405, 0.09720405, 0.09970405, 0.10220405, 0.10470405,
         0.10720405, 0.10970405, 0.11220405, 0.11470405, 0.11720405,
         0.11970405, 0.12220405, 0.12470405, 0.12720405, 0.12970405]),
  jnp.array([0.02641612, 0.02545857, 0.02447167, 0.02345581, 0.02241125,
         0.02133829, 0.02023758, 0.01911054, 0.01796036, 0.01679381,
         0.01562486, 0.01448212, 0.01342213, 0.01254578, 0.01198868,
         0.01184018, 0.01206377, 0.01254688, 0.01318733, 0.01391874,
         0.01470226, 0.01551549, 0.0163453 , 0.01718381, 0.01802615]),
  0.09970405414511939,
  0.25)]

x0 = jnp.array([0.01, 0.00, 0.10])
bounds = (jnp.array([0.0001, -0.9999, 0.0001]), jnp.array([999, 0.9999, 999]))
args = data[0]

# This fails: the objective produces NaNs, apparently because the internal
# line search immediately steps outside the bounds
solver = jaxopt.LBFGSB(fun=_obj)
results = solver.run(x0, bounds=bounds, args=args)

# However, the objective function evaluates to a finite value at x0
_obj(x0, args)
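
To illustrate, the NaN can be reproduced by evaluating the objective just outside the box. The probe point below is illustrative, not taken from an actual solver trace:

# With rho pushed past its upper bound, 1 - 2 * rho * zeta + zeta**2 goes
# negative for some strikes, so the sqrt inside normal_vol returns nan
_obj(jnp.array([0.01, 1.5, 0.10]), args)  # -> nan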
charles-zhng commented 7 months ago

I might be having the same issue; please let me know if you learn anything new! For now I am clipping to the bounds myself in my loss function, along the lines of the sketch below.
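
A minimal sketch of that workaround, reusing _obj, x0, bounds, and args from the reproduction above (_obj_clipped is just an illustrative name, not a jaxopt API):

import jax.numpy as jnp
import jaxopt

lower, upper = bounds

def _obj_clipped(params, args):
    # Project the iterate back into the box before evaluating, so the line
    # search never feeds out-of-bounds values into normal_vol
    return _obj(jnp.clip(params, lower, upper), args)

solver = jaxopt.LBFGSB(fun=_obj_clipped)
results = solver.run(x0, bounds=bounds, args=args)

One caveat: jnp.clip zeroes the gradient along any coordinate that is actually being clipped, so this hides the NaNs rather than fixing the line search itself.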