lanl-ansi / MathOptAI.jl

Embed trained machine learning predictors in JuMP
https://lanl-ansi.github.io/MathOptAI.jl/

[GrayBox] add Hessian support #90

Closed: odow closed this 1 month ago

odow commented 1 month ago

Suggested by @pulsipher in #82

> As discussed in https://github.com/lanl-ansi/MathOptAI.jl/issues/70, it would be very nice to automate the embedding of NNs (and other predictors) as black-box functions that are treated as nonlinear operators. In my research on using smooth NNs in nonlinear optimal control formulations, we have found that treating the NN as an operator that gets its derivatives from the ML environment (e.g., using torch.func from PyTorch) significantly outperforms embedding the NN as algebraic constraints (benchmarking OMLT against PyNumero's gray-box API).
>
> Naturally, JuMP's nonlinear operator API is scalarized, so I am not sure how well it will work for predictors with many inputs and outputs. This definitely motivates the milestone to add vectorized operator support in JuMP.

To which I replied

> For black-box outputs, we could automate wrapping @operator and building the appropriate derivatives. And for vector-valued, we could also automate the memoization stuff.
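For concreteness, a minimal sketch of what that wrapping could look like with JuMP's current scalar operator API. The names `predict` and `predict_gradient` are hypothetical stand-ins for calls into the ML framework; the point is that the gradient is supplied explicitly, so JuMP never differentiates through the network itself:

```julia
using JuMP

# Hypothetical stand-in for a trained predictor; in practice `predict`
# and `predict_gradient` would call out to the ML framework
# (e.g., torch.func.grad in PyTorch).
predict(x...) = x[1]^2 + sin(x[2])
function predict_gradient(g::AbstractVector, x...)
    g[1] = 2 * x[1]    # ∂f/∂x₁
    g[2] = cos(x[2])   # ∂f/∂x₂
    return
end

model = Model()
@variable(model, x[1:2])
# Register the predictor as a user-defined nonlinear operator with an
# explicit gradient callback.
@operator(model, op_predict, 2, predict, predict_gradient)
@objective(model, Min, op_predict(x[1], x[2]))
```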

odow commented 1 month ago

I have a prototype, but it needs https://github.com/jump-dev/MathOptInterface.jl/issues/2534 for input=1 models.

odow commented 1 month ago

#96 implements most of what we want out of this.

@pulsipher would like the ability to add Hessians to the nonlinear callback.
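A sketch of what that could look like via JuMP's existing operator API, which already accepts an optional Hessian callback that fills in the lower triangle of `H`. The hand-written second derivatives below are stand-ins for values that would come from the ML framework (e.g., torch.func.hessian):

```julia
using JuMP

# Same hypothetical predictor as above, now with second derivatives.
predict(x...) = x[1]^2 + sin(x[2])
function predict_gradient(g::AbstractVector, x...)
    g[1] = 2 * x[1]
    g[2] = cos(x[2])
    return
end
function predict_hessian(H::AbstractMatrix, x...)
    # JuMP requires only the lower-triangular part to be filled in.
    H[1, 1] = 2.0
    H[2, 2] = -sin(x[2])   # the cross term H[2, 1] is zero here
    return
end

model = Model()
@variable(model, x[1:2])
@operator(model, op_predict, 2, predict, predict_gradient, predict_hessian)
@objective(model, Min, op_predict(x[1], x[2]))
```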

odow commented 1 month ago

We also discussed the ability to re-use the operator for different inputs.

But I don't know that I like it because it would interfere with the cache. I think we should try the existing behavior before thinking about improvements.

We might also consider accepting a ::Matrix input to batch calls.
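To illustrate the batching idea (none of this is current MathOptAI API; `predict_batch` is hypothetical): an ML framework can evaluate many points in one forward pass if they arrive as the columns of a matrix, rather than one call per point:

```julia
# Hypothetical batched predictor: each column of X is one input point,
# so a single call to the ML framework evaluates every point at once.
predict_batch(X::Matrix{Float64}) = X .^ 2  # stand-in for one NN forward pass

X = rand(2, 100)          # 100 two-dimensional input points
Y = predict_batch(X)      # 100 outputs from a single batched call
```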

pulsipher commented 1 month ago

> But I don't know that I like it because it would interfere with the cache. I think we should try the existing behavior before thinking about improvements.

This is a notable limitation of having to use memoization, which relies on the cache. I believe that https://github.com/jump-dev/MathOptInterface.jl/issues/2402 would solve this problem. I think it is intuitive to have a nonlinear operator that isn't tied to particular variable inputs.

The other thing I wonder is how well memoized nonlinear operators that depend on splatted inputs will perform as the number of inputs and outputs becomes larger (say, on the order of 100 or 1000).
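For reference, a minimal sketch of the memoization pattern under discussion, adapted from the vector-output operator example in the JuMP documentation and assuming derivatives are supplied separately as above. All of the scalarized output operators share a single-entry cache keyed on the splatted input tuple, which is why re-using the operator at a second input point would thrash the cache, and why very wide input tuples raise the performance question:

```julia
# Memoize a vector-valued `f` so that its n scalarized outputs share one
# forward pass. `last_x` is a single-entry cache: only the most recent
# input tuple is remembered, so alternating between two input points
# forces a fresh forward pass on every call.
function memoize(f::Function, n_outputs::Int)
    last_x, last_y = nothing, nothing
    function f_i(i::Int, x::Float64...)
        if x !== last_x
            last_x, last_y = x, f(x...)
        end
        return last_y[i]
    end
    return [(x...) -> f_i(i, x...) for i in 1:n_outputs]
end
```

Each element of `memoize(predict, m)` would then be registered as its own scalar `@operator`, one per output.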

odow commented 1 month ago

Yip. https://github.com/jump-dev/MathOptInterface.jl/issues/2402 would fix this. But that's a much more complicated issue :smile:

> The other thing I wonder is how well memoized nonlinear operators that depend on splatted inputs will perform as the number of inputs and outputs becomes larger

Your guess is as good as mine. Probably poorly. But we can look to improve performance once we have some examples.