diku-dk / futhark

A data-parallel functional programming language
http://futhark-lang.org
ISC License

Implement forward- and reverse-mode AD in the interpreter #2186

Closed by vox9 1 month ago

vox9 commented 1 month ago

Apologies for closing the old PR; I am quite new to this.

So, as promised, I cleaned up my code a bit. That being said, more work needs doing.

Here is a list of tasks that immediately spring to mind:

  1. Clean up vjp2 and jvp2. They are ugly, and hugely inefficient. It seems implementing them is my kryptonite. I look forward to seeing them have the beauty they deserve ;)
  2. Fix the horrible time complexity of deriveTape. I think this can be achieved by either (1) implementing Tape as a graph instead of a tree, or (2) assigning each TapeOp a unique ID using a counter in EvalM. With (2), deriveTape would first run through the Tape, putting each unique TapeOp in a lookup table and counting the references to it. The Tape can then be derived starting from the output. Each time a reference to a TapeOp is encountered, the sensitivity propagated to it is added to a pool kept in the lookup table, and its reference count is decreased by one. When the reference count reaches zero, that TapeOp is derived using the pooled sensitivity. Thus, each TapeOp is derived only once.
  3. Make sure the error messages fit in.
  4. Perhaps move more responsibility from Interpreter.hs to AD.hs. I feel like the former uses a lot of functions from the latter, making the code unnecessarily complex to read.
  5. Perhaps use doOp for computations of ValuePrims. This would make the code for applying mathematical operations cleaner. Currently, it contains a lot of similar or duplicate code.
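
To make option (2) in item 2 a bit more concrete, here is a minimal, self-contained sketch of the reference-counting idea. Everything in it is hypothetical: `Tape`, `Node`, `St`, `countRefs`, and `push` are toy stand-ins, not the definitions in AD.hs, and the IDs are assigned by hand instead of by a counter in EvalM. It only handles scalar addition and multiplication, so treat it as an illustration of the traversal order, not as a proposed implementation.

```haskell
import qualified Data.Map.Strict as M

-- Hypothetical toy types; NOT the Tape/TapeOp definitions in AD.hs.
type Id = Int

data Node
  = Input String -- a variable we differentiate with respect to
  | Add Tape Tape
  | Mul Tape Tape

-- Every tape node carries a unique ID (assigned by hand in this sketch)
-- and the primal value computed during the forward pass.
data Tape = Tape {tapeId :: Id, tapeVal :: Double, tapeNode :: Node}

children :: Tape -> [Tape]
children t = case tapeNode t of
  Input _ -> []
  Add a b -> [a, b]
  Mul a b -> [a, b]

-- Pass 1: count the references to each unique node. A node's children
-- are only visited the first time the node is seen.
countRefs :: Tape -> M.Map Id Int
countRefs root = go root M.empty
  where
    go t m = case M.lookup (tapeId t) m of
      Just n -> M.insert (tapeId t) (n + 1) m
      Nothing -> foldr go (M.insert (tapeId t) 1 m) (children t)

data St = St
  { pending :: M.Map Id Int, -- references that have not yet delivered a sensitivity
    pool :: M.Map Id Double, -- sensitivities accumulated so far
    grads :: M.Map String Double -- resulting gradient, keyed by input name
  }

-- Pass 2: push a sensitivity into a node. The node is only derived once
-- its last outstanding reference has delivered its contribution.
push :: Double -> Tape -> St -> St
push s t st
  | left > 0 = st'
  | otherwise = case tapeNode t of
      Input name -> st' {grads = M.insertWith (+) name total (grads st')}
      Add a b -> push total b (push total a st')
      Mul a b -> push (total * tapeVal a) b (push (total * tapeVal b) a st')
  where
    i = tapeId t
    left = (pending st M.! i) - 1
    total = M.findWithDefault 0 i (pool st) + s
    st' = st {pending = M.insert i left (pending st), pool = M.insert i total (pool st)}

deriveTape :: Tape -> M.Map String Double
deriveTape root = grads (push 1 root (St (countRefs root) M.empty M.empty))

-- f = y + y with y = x * x, evaluated at x = 3. Both y and x are shared.
example :: M.Map String Double
example = deriveTape f
  where
    x = Tape 0 3 (Input "x")
    y = Tape 1 9 (Mul x x)
    f = Tape 2 18 (Add y y)
```

In `example`, `y` receives two sensitivities of 1 before it is derived once with the pooled value 2, and `x` is likewise derived once, so the amount of work is proportional to the number of edges in the tape instead of the number of paths through it.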

I have also littered the code with TODOs just ripe for the taking, and added a lot of explanatory text, as you mentioned that you would use this in your teaching. I have probably added too much, so feel free to delete it.

vox9 commented 1 month ago

It is probably worth mentioning that the old version of the code can be found here: interpreter-ad-old

athas commented 1 month ago

Will you fix the remaining style errors or shall I?

vox9 commented 1 month ago

Honestly, I'd love to, but I'm not entirely sure that I can within a reasonable time frame. I'm not yet comfortable enough to feel that I understand the "Haskell" way of doing things, nor even the functional way - my ugly implementation of vjp2 and jvp2 is a great example of this. However, if you're up for having a chat about some of the details, I could probably shine it up pretty well. I'm also quite curious as to what you will be using the code for.

athas commented 1 month ago

You literally just have to run the ormolu formatter on the source code: https://github.com/diku-dk/futhark/blob/master/STYLE.md#ormolu

vox9 commented 1 month ago

Wow, doesn't get any easier than that ;) I'll give it a try right away

vox9 commented 1 month ago

Alright, I did it, but it changed a bunch of files which have nothing to do with this PR. I imagine you only want me to commit the changes to AD.hs, Values.hs, and Interpreter.hs?

vox9 commented 1 month ago

Oh, my bad, my tired eyes missed the line "If you find yourself working on such code, please reformat it while you are there." Committing now.

athas commented 1 month ago

Thank you for the work. I have merged your implementation and created this issue to address the most significant remaining problem: https://github.com/diku-dk/futhark/issues/2187

I would certainly welcome further contributions, but the current implementation is operational.