tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

L-BFGS optimizer does not support multiple batch dimensions #1366

Open sean-escola opened 3 years ago

sean-escola commented 3 years ago

The API documentation for lbfgs_minimize states that an arbitrary number of batch dimensions is supported, but this is not the case. The failure comes from the call where(cond, tval, fval) on line 43 of hager_zhang_lib.py. If cond is a vector, where interprets it as indexing the outermost dimension of tval and fval. Here, however, cond has one dimension per batch dimension, so with two or more batch dimensions where raises an error.
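
To make the failure mode concrete, here is a small NumPy sketch. It assumes the where on that line follows TF1 semantics, as tf.compat.v1.where is documented to: a condition with the same shape as the values selects element-wise, a rank-1 condition selects whole slices along the outermost dimension, and any other shape is rejected. The helper where_v1_semantics is hypothetical and only emulates that rule with np.where; it is not TensorFlow code.

```python
import numpy as np


def where_v1_semantics(cond, tval, fval):
    """Hypothetical emulation of the TF1-style three-argument `where` rule."""
    cond = np.asarray(cond, dtype=bool)
    tval = np.asarray(tval)
    fval = np.asarray(fval)
    if tval.shape != fval.shape:
        raise ValueError("tval and fval must have the same shape")
    if cond.shape == tval.shape:
        # Same shape as the values: plain element-wise selection.
        return np.where(cond, tval, fval)
    if cond.ndim == 1 and cond.shape[0] == tval.shape[0]:
        # Rank-1 cond picks whole slices along the outermost dimension.
        cond = cond.reshape((-1,) + (1,) * (tval.ndim - 1))
        return np.where(cond, tval, fval)
    # Anything else (e.g. a rank-2 cond against rank-3 values) is rejected.
    raise ValueError(
        f"cond shape {cond.shape} incompatible with value shape {tval.shape}")


# One batch dimension: cond is [2], values are [2, 3]. This works.
cond1 = np.array([True, False])
tval1, fval1 = np.ones((2, 3)), np.zeros((2, 3))
result1 = where_v1_semantics(cond1, tval1, fval1)

# Two batch dimensions: cond is [1, 2], values are [1, 2, 3]. This fails.
error_message = None
try:
    where_v1_semantics(cond1[None, :], tval1[None], fval1[None])
except ValueError as err:
    error_message = str(err)

print(result1)
print(error_message)
```

The second call mirrors the start3 case in the reproduction: cond carries both batch dimensions, so it is neither the full value shape nor a rank-1 vector.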

The fix is to call where with cond[..., None] (so that cond is broadcastable against tval and fval) whenever cond has more than one dimension. Some such fix is needed for the L-BFGS (and BFGS) optimizers to work with multiple batch dimensions, but I'm not sure where in the call stack it belongs. (I patched hager_zhang_lib.py directly to get it working for me.)
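
The proposed fix amounts to giving cond trailing singleton axes so that it broadcasts against the values. Below is a minimal NumPy sketch, assuming a broadcasting where (np.where broadcasts its arguments, as tf.where does in TF2); the helper name batched_where is hypothetical. For a single event dimension the reshape is exactly cond[..., None]; in TensorFlow the analogous call would presumably be tf.where(cond[..., tf.newaxis], tval, fval).

```python
import numpy as np


def batched_where(cond, tval, fval):
    """Select between tval and fval per batch member (hypothetical helper).

    cond has the batch shape B; tval and fval have shape B + event_shape.
    Trailing singleton axes are appended to cond so it broadcasts against
    the values. With one event dimension this equals cond[..., None].
    """
    cond = np.asarray(cond, dtype=bool)
    tval = np.asarray(tval)
    fval = np.asarray(fval)
    extra_axes = tval.ndim - cond.ndim
    cond = cond.reshape(cond.shape + (1,) * extra_axes)
    # np.where broadcasts cond against tval and fval.
    return np.where(cond, tval, fval)


# Two batch dimensions: cond is [1, 2], values are [1, 2, 3].
cond = np.array([[True, False]])
tval = np.ones((1, 2, 3))
fval = np.zeros((1, 2, 3))
result = batched_where(cond, tval, fval)
print(result)
```

Because the reshape is a no-op when cond already has the values' rank, the same helper also covers the one-batch-dimension case that already works.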

Here's code to reproduce the problem (based on the example in the API documentation):

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

# A high-dimensional quadratic bowl.
ndims = 60
minimum = np.ones([ndims], dtype='float64')
scales = np.arange(ndims, dtype='float64') + 1.0

# The objective function and the gradient.
def quadratic_loss_and_gradient(x):
    return tfp.math.value_and_gradient(
        lambda x: tf.reduce_sum(scales * tf.math.squared_difference(x, minimum), axis=-1),
        x
    )

start1 = np.arange(ndims, 0, -1, dtype='float64')  # no batch dimensions
start2 = start1[None, :]  # 1 batch dimension
start3 = start1[None, None, :]  # 2 batch dimensions

# This works
optim_results = tfp.optimizer.lbfgs_minimize(
    quadratic_loss_and_gradient,
    initial_position=start1
)

# And this works
optim_results = tfp.optimizer.lbfgs_minimize(
    quadratic_loss_and_gradient,
    initial_position=start2
)

# But this fails
optim_results = tfp.optimizer.lbfgs_minimize(
    quadratic_loss_and_gradient,
    initial_position=start3
)
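
Since the report shows that a single batch dimension works, one possible interim workaround (not proposed in the thread itself) is to collapse all batch dimensions into one before calling the optimizer and reshape the result afterwards. This relies on the batch members being independent problems, which holds here because the objective reduces only over axis=-1. minimize_with_flattened_batch is a hypothetical helper, not a TFP API; the lbfgs_minimize call appears only in a comment, and the runnable part uses a stand-in minimizer.

```python
import numpy as np


def minimize_with_flattened_batch(minimize_fn, initial_position):
    """Hypothetical workaround helper (not part of TFP).

    initial_position has shape batch_shape + [ndims]. All batch dimensions
    are collapsed into one, minimize_fn is called on the resulting
    [num_problems, ndims] array, and the returned positions are reshaped
    back to batch_shape + [ndims].
    """
    initial_position = np.asarray(initial_position)
    batch_shape = initial_position.shape[:-1]
    ndims = initial_position.shape[-1]
    flat_start = initial_position.reshape((-1, ndims))
    flat_position = np.asarray(minimize_fn(flat_start))
    return flat_position.reshape(batch_shape + (ndims,))


# With TFP this would be called roughly as follows (not executed here):
# position = minimize_with_flattened_batch(
#     lambda s: tfp.optimizer.lbfgs_minimize(
#         quadratic_loss_and_gradient, initial_position=s).position,
#     start3)

# Self-contained check with a stand-in "minimizer" that jumps straight to
# the known minimum of the quadratic bowl (all ones).
ndims = 60
start3 = np.arange(ndims, 0, -1, dtype='float64')[None, None, :]
position = minimize_with_flattened_batch(np.ones_like, start3)
print(position.shape)
```

The reshape round trip keeps the caller's batch layout, so downstream code written for start3 does not need to change.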

sean-escola commented 3 years ago

@srvasude Thanks for picking up this bug. Do you have a sense of the best place in the stack to make a fix?

sean-escola commented 3 years ago

Anybody able to take a look at this?

srvasude commented 3 years ago

Hi, taking a look at this. Will give you an update today or tomorrow.

sean-escola commented 3 years ago

@srvasude Any update? Thanks!