SciProgCentre / kmath

Kotlin mathematics extensions library

Integration with kotlingrad #149

Closed altavir closed 3 years ago

altavir commented 3 years ago

https://github.com/breandan/kotlingrad

The idea is to connect kotlingrad automatic differentiation to kmath via MST expressions. One should be able to transform MST expressions into a kotlingrad node graph and back again. This way we will be able to perform autodiff on any algebra. The integration may require some changes to the kotlingrad API.

cc @breandan
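A minimal sketch of the proposed round trip, assuming two simplified, hypothetical tree types: `Mst` loosely mimics an MST (operations identified by name), while `Graph` stands in for a kotlingrad-style node graph with one class per operation. Neither is the real kmath or kotlingrad class hierarchy, and only `+` and `*` are covered.

```kotlin
// Hypothetical MST-like tree: untyped, operations identified by name.
sealed class Mst {
    data class Num(val value: Double) : Mst()
    data class Sym(val name: String) : Mst()
    data class Binary(val op: String, val left: Mst, val right: Mst) : Mst()
}

// Hypothetical kotlingrad-like node graph: one class per operation.
sealed class Graph {
    data class Const(val value: Double) : Graph()
    data class Variable(val name: String) : Graph()
    data class Sum(val left: Graph, val right: Graph) : Graph()
    data class Prod(val left: Graph, val right: Graph) : Graph()
}

// MST -> graph: can fail, because MST accepts any operation name.
fun Mst.toGraph(): Graph = when (this) {
    is Mst.Num -> Graph.Const(value)
    is Mst.Sym -> Graph.Variable(name)
    is Mst.Binary -> when (op) {
        "+" -> Graph.Sum(left.toGraph(), right.toGraph())
        "*" -> Graph.Prod(left.toGraph(), right.toGraph())
        else -> throw IllegalArgumentException("Unsupported operation '$op'")
    }
}

// Graph -> MST: total, every node type has an MST counterpart.
fun Graph.toMst(): Mst = when (this) {
    is Graph.Const -> Mst.Num(value)
    is Graph.Variable -> Mst.Sym(name)
    is Graph.Sum -> Mst.Binary("+", left.toMst(), right.toMst())
    is Graph.Prod -> Mst.Binary("*", left.toMst(), right.toMst())
}

fun main() {
    // x * x + 3
    val mst = Mst.Binary("+", Mst.Binary("*", Mst.Sym("x"), Mst.Sym("x")), Mst.Num(3.0))
    val graph = mst.toGraph()
    println(graph.toMst() == mst) // prints true: the round trip is lossless here
}
```

The asymmetry is the interesting design point: the MST side is open-ended, so converting into the typed graph has to reject operations it has no node for, while converting back is always possible.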

breandan commented 3 years ago

Hey @altavir, thanks for the suggestion! This has been on my to-do list for a while; I've meant to contribute it but haven't been able to find the time. Philosophically, the APIs are similar enough that translating or integrating KG should not require a tremendous amount of effort. Implementing AD on scalar expressions just requires pattern matching on the MST nodes and replacing each one with the corresponding derivative rule; for elementary functions the implementation is probably ~20 lines long (see here for what that looks like in KG).

I might not be able to get to this very soon, but in case anyone wants to start the ball rolling, happy to leave feedback!
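To make the pattern-matching idea concrete, here is a sketch of symbolic differentiation over a hypothetical, simplified MST-like tree (`Node`, with unary and binary operations identified by name, plus helper functions `add` and `mul`). It is not kmath's MST API or KG's implementation; it covers only `+`, `*`, `sin`, `cos` and `exp`, and it does not simplify the result.

```kotlin
import kotlin.math.cos
import kotlin.math.exp
import kotlin.math.sin

// Hypothetical, simplified MST-like tree: operations are identified by name.
sealed class Node {
    data class Num(val value: Double) : Node()
    data class Sym(val name: String) : Node()
    data class Unary(val op: String, val arg: Node) : Node()
    data class Binary(val op: String, val left: Node, val right: Node) : Node()
}

fun add(a: Node, b: Node): Node = Node.Binary("+", a, b)
fun mul(a: Node, b: Node): Node = Node.Binary("*", a, b)

// Pattern-match on each node and replace it with its derivative rule.
fun Node.d(x: String): Node = when (this) {
    is Node.Num -> Node.Num(0.0)
    is Node.Sym -> Node.Num(if (name == x) 1.0 else 0.0)
    is Node.Binary -> when (op) {
        "+" -> add(left.d(x), right.d(x))                        // sum rule
        "*" -> add(mul(left.d(x), right), mul(left, right.d(x))) // product rule
        else -> throw IllegalArgumentException("No rule for binary '$op'")
    }
    is Node.Unary -> {
        // Chain rule: (f(g))' = f'(g) * g'
        val outer: Node = when (op) {
            "sin" -> Node.Unary("cos", arg)
            "cos" -> mul(Node.Num(-1.0), Node.Unary("sin", arg))
            "exp" -> this
            else -> throw IllegalArgumentException("No rule for unary '$op'")
        }
        mul(outer, arg.d(x))
    }
}

// Evaluate the tree with Double arithmetic.
fun Node.eval(env: Map<String, Double>): Double = when (this) {
    is Node.Num -> value
    is Node.Sym -> env.getValue(name)
    is Node.Binary -> when (op) {
        "+" -> left.eval(env) + right.eval(env)
        "*" -> left.eval(env) * right.eval(env)
        else -> throw IllegalArgumentException("Unknown binary '$op'")
    }
    is Node.Unary -> when (op) {
        "sin" -> sin(arg.eval(env))
        "cos" -> cos(arg.eval(env))
        "exp" -> exp(arg.eval(env))
        else -> throw IllegalArgumentException("Unknown unary '$op'")
    }
}

fun main() {
    val x = Node.Sym("x")
    val f = mul(Node.Unary("sin", x), Node.Unary("exp", x)) // f(x) = sin(x) * exp(x)
    val df = f.d("x")                   // cos(x)*1*exp(x) + sin(x)*(exp(x)*1), unsimplified
    println(df.eval(mapOf("x" to 0.0))) // prints 1.0
}
```

The derivative is built once as a new tree and can then be evaluated at as many points as needed, which is why the differentiation step itself is cheap compared to evaluation.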

altavir commented 3 years ago

@CommanderTvis added a basic connector, but I am not sure I like the user experience and API design; I would like a more seamless experience, so feedback is appreciated. @rgrit91?

grinisrit commented 3 years ago

I think that, because this is such a performance-sensitive part, one needs to write benchmark tests from the start to track any overhead introduced by another layer of abstraction. It is also worth keeping an eye on the future, in case benchmarks against other autograd libraries are required.

altavir commented 3 years ago

I am not sure that the automatic differentiation logic itself is performance sensitive. It is symbolic, after all, and is done only once. Computing the resulting expression could be heavy, but that depends on the algebra.

CommanderTvis commented 3 years ago

The cost of evaluating an MST depends only on the complexity of the functions the algebra provides and on the interpreter overhead, which is minimal (at least for most workloads) for ASM-generated expressions.

grinisrit commented 3 years ago

For some applications I have in mind, the DFG might need to be regenerated at every sample/step. In any case, I think adding some benchmarks would be nice for a library like kmath, and I am happy to contribute them once I find a bit of time.

altavir commented 3 years ago

I think @rgrit91 is talking about TensorFlow-style gradient computation, which is not quite the same: it operates on large matrices from the beginning. In our case there are two phases. The first is the actual autodiff, which is done once on the node graph and is not performance critical. The second is using the resulting expression to compute the derivative value. That can be time-consuming for large matrices, but the cost depends on the algebra, not on autodiff.

breandan commented 3 years ago

The primary focus of KG is usability and type safety. It is important to have benchmarks, but I would be careful to avoid premature optimization. We have explored some staging optimizations and although there are plenty which could improve performance, until there is a GPU backend these optimizations are mostly negligible.

One area where we've spent a lot of effort is variable binding and invocation. Writing f.invoke(bindings).toDouble() becomes tedious, and it may throw an error. Suppose you have val f = x + y + z; val g = f(1, 2, 3). Is it possible to infer that g is an Int or a Double instead of an SFun<...>? We have a prototype that works on a small alphabet. I'm not sure it is the best solution, but I would encourage you to think very carefully about how the eDSL interacts with the language's type system; IMO this is one of the most important things to get right.
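One way to get g typed as a number is to make the arity part of the function's type, so that `invoke` binds arguments by position and returns a `Double` directly. The sketch below uses hypothetical names (`Expr`, `Fun3`, `fun3`) and is not the KG prototype mentioned above; it only shows the idea for a fixed three-variable scope.

```kotlin
// Hypothetical sketch: the arity is encoded in the type (Fun3), so invoking
// with the right number of arguments yields a Double, not another function.
interface Expr {
    fun eval(args: DoubleArray): Double
}

// A free variable refers to a fixed argument position. It is private, so
// variables can only be obtained through fun3 below.
private class Var(private val position: Int) : Expr {
    override fun eval(args: DoubleArray) = args[position]
}

private class Sum(private val a: Expr, private val b: Expr) : Expr {
    override fun eval(args: DoubleArray) = a.eval(args) + b.eval(args)
}

private class Product(private val a: Expr, private val b: Expr) : Expr {
    override fun eval(args: DoubleArray) = a.eval(args) * b.eval(args)
}

operator fun Expr.plus(other: Expr): Expr = Sum(this, other)
operator fun Expr.times(other: Expr): Expr = Product(this, other)

class Fun3(private val body: Expr) {
    // Exactly three arguments; a wrong arity is a compile-time error.
    operator fun invoke(x: Number, y: Number, z: Number): Double =
        body.eval(doubleArrayOf(x.toDouble(), y.toDouble(), z.toDouble()))
}

// Builds an expression in a scope that provides exactly three variables.
fun fun3(build: (x: Expr, y: Expr, z: Expr) -> Expr): Fun3 =
    Fun3(build(Var(0), Var(1), Var(2)))

fun main() {
    val f = fun3 { x, y, z -> x + y + z }
    val g = f(1, 2, 3) // g is inferred as Double; f(1, 2) would not compile
    println(g)         // prints 6.0
}
```

The trade-off is that arity is fixed per class; supporting partial application or named bindings would need more types, which is where the interaction with the type system gets hard.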

@CommanderTvis I took a quick look at the PR. Looks like you're off to a good start! Just so I understand, is your goal to write an AST converter or extend the KMath algebra? Based on prior discussion I know it is preferred to extend or wrap objects in the KMath type system. If it were me, I wouldn't spend a lot of effort translating KG verbatim, but just look over the README and try adapting it to MST where it makes sense. I think the ideal result would be an MST based AD/SD, but with some of the design patterns from KG. Happy to answer any specific questions about our design choices.

CommanderTvis commented 3 years ago

@breandan my current idea is to use KotlinGrad to transform MST expressions: (very conceptual syntax)

```kotlin
val x by MstAlgebra.symbol
val f = "x^2-4*x-1".parseMath()
val derivative = f.transformAsSFun(DoublePrecision.prototype) { it.d(x) } // returns MST
println(MstExpression(RealField, derivative).compile()("x" to 321.0))
```

CommanderTvis commented 3 years ago

It would also be possible to encapsulate the Kotlingrad internals entirely and just provide a KMath-like differentiation interface.

grinisrit commented 3 years ago

@breandan I just discovered KG, and it is indeed a jewel in that sense. It's just that @altavir and I plan to run some heavy computations involving AD in the near future. I am coming from torch, so I am obviously interested in any kind of comparison, in particular of performance (even just on CPU), but I guess we just need to try it and see how it goes.

altavir commented 3 years ago

I thought a bit more about the integration design. As I see right now, the most important feature is the ability to compute kotlingrad expressions in kmath algebraic contexts. This way we can do analytical differentiation of expressions with arbitrary types and then compute them using a numerically optimized algorithm from a plugin (precisely what @rgrit91 wants).

So the most important workflow is the following:

  1. Generate an expression of type T using the kotlingrad API
  2. Transform the resulting expression into MST
  3. Evaluate the MST in a specific algebra that operates on type T.

It seems like everything we need is already in #150. Now we only need to hide the MST conversion from the user, who does not want to know about our implementation details. All we need to do is add an extension on a kotlingrad expression that computes the expression in a given algebra, like fun <T> SFun.compute(algebra: Algebra<T>): T. This way we get both the kotlingrad design and performant operations through the kmath integration. The autodiff part won't affect performance, since it does not do the actual computation.
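The proposed compute extension can be sketched as a single recursive evaluator that delegates every operation to the algebra. Everything here is hypothetical: `Ring<T>` is a minimal stand-in for kmath's algebra interfaces (which are richer and named differently), `Expr` stands in for `SFun`, and an explicit `bindings` map is added because the free variables need values from somewhere.

```kotlin
// Hypothetical minimal algebra over T; not kmath's actual Algebra interface.
interface Ring<T> {
    fun number(value: Int): T
    fun add(a: T, b: T): T
    fun multiply(a: T, b: T): T
}

object DoubleRing : Ring<Double> {
    override fun number(value: Int) = value.toDouble()
    override fun add(a: Double, b: Double) = a + b
    override fun multiply(a: Double, b: Double) = a * b
}

data class Complex(val re: Double, val im: Double)

object ComplexRing : Ring<Complex> {
    override fun number(value: Int) = Complex(value.toDouble(), 0.0)
    override fun add(a: Complex, b: Complex) = Complex(a.re + b.re, a.im + b.im)
    override fun multiply(a: Complex, b: Complex) =
        Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
}

// Hypothetical stand-in for a symbolic expression (e.g. a derivative).
sealed class Expr {
    data class Num(val value: Int) : Expr()
    data class Sym(val name: String) : Expr()
    data class Add(val left: Expr, val right: Expr) : Expr()
    data class Mul(val left: Expr, val right: Expr) : Expr()
}

// The only entry point a user sees: evaluate the expression in any algebra over T.
fun <T> Expr.compute(algebra: Ring<T>, bindings: Map<String, T>): T = when (this) {
    is Expr.Num -> algebra.number(value)
    is Expr.Sym -> bindings[name] ?: throw IllegalArgumentException("Unbound symbol '$name'")
    is Expr.Add -> algebra.add(left.compute(algebra, bindings), right.compute(algebra, bindings))
    is Expr.Mul -> algebra.multiply(left.compute(algebra, bindings), right.compute(algebra, bindings))
}

fun main() {
    val x = Expr.Sym("x")
    val f = Expr.Add(Expr.Mul(x, x), Expr.Num(1)) // f(x) = x*x + 1
    println(f.compute(DoubleRing, mapOf("x" to 3.0)))                // prints 10.0
    println(f.compute(ComplexRing, mapOf("x" to Complex(0.0, 1.0)))) // prints Complex(re=0.0, im=0.0)
}
```

The same expression object is evaluated in two different algebras without any change to the expression itself; only the `Ring` implementation decides what `+` and `*` mean, which is where a numerically optimized plugin would slot in.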

Another use case is using kotlingrad with kmath-generated expressions. I think that #150 covers MST integration, but we probably won't use the MST algebra when we have the much more advanced kotlingrad. What I have in mind is to provide an API for differentiable expressions in kmath and bring all autodiff tools under that same API. Then we can convert differentiable expressions from other frameworks and integrate them with kotlingrad under the same API. It is not urgent; I will think about it a bit more.