Open sean-escola opened 3 years ago
@srvasude Thanks for picking up this bug. Do you have a sense of the best place in the stack to make a fix?
Anybody able to take a look at this?
Hi, Taking a look at this. Will give you an update today / tomorrow.
@srvasude Any update? Thanks!
The API for `lbfgs_minimize` reports that any arbitrary number of batch dimensions are supported, but this is not the case. The failure is the use of `where(cond, tval, fval)` on line 43 of `hager_zhang_lib.py`. If `cond` is a vector then `where` interprets it as referencing the outer dimension of `tval` and `fval`. But here, `cond` will have multiple dimensions (equal to the batch dimensions), so `where` errors.

The solution is to call `where` with `cond[..., None]` (i.e., so that `cond` is broadcastable with `tval` and `fval`) if dimensionality of `cond` is greater than 1. This is needed for the L-BFGS (and BFGS) optimizer to work, but I'm not sure where in the stack is the best place to fix this. (I hacked `hager_zhang_lib.py` directly to make it work for me).

Here's code to replicate the problem (based on the sample in the API):
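Separately from the reproduction, here is a minimal sketch of the shape mismatch and of the `cond[..., None]` idea. It uses NumPy's `where` as a stand-in rather than the TensorFlow call in `hager_zhang_lib.py`. NumPy fails for a slightly different reason (it aligns trailing dimensions when broadcasting), but the remedy has the same shape logic. The helper name `batched_where` and the example shapes are assumptions made for illustration; they are not taken from the library.

```python
import numpy as np


def batched_where(cond, tval, fval):
    """Hypothetical helper: pick tval or fval per batch member.

    cond carries only the batch dimensions; tval and fval carry the batch
    dimensions plus extra trailing (event) dimensions.
    """
    cond = np.asarray(cond)
    tval = np.asarray(tval)
    fval = np.asarray(fval)
    # Append trailing singleton axes so cond broadcasts against the values.
    # For a single extra axis this is the same as cond[..., None].
    extra = max(tval.ndim, fval.ndim) - cond.ndim
    if extra > 0:
        cond = cond.reshape(cond.shape + (1,) * extra)
    return np.where(cond, tval, fval)


# Two batch dimensions (2, 3) and one event dimension of size 4.
cond = np.array([[True, False, True],
                 [False, False, True]])
tval = np.ones((2, 3, 4))
fval = np.zeros((2, 3, 4))

# The naive call cannot broadcast shape (2, 3) against (2, 3, 4).
try:
    np.where(cond, tval, fval)
    naive_failed = False
except ValueError:
    naive_failed = True

# With the trailing axis added, each batch member is selected as a whole.
result = batched_where(cond, tval, fval)
```

Expanding `cond` inside one small helper, rather than at each call site, is one way to keep the fix in a single place. Whether that is the right layer of the stack is the open question raised above.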