pavanky closed this pull request 7 years ago.
@botev @jramapuram @itsnarsi This has been a long time coming, but I'd appreciate any feedback you might have as well.
CC @arrayfire/core-devel
@Reithan too
Awesome work @pavanky. Will take a look in more detail when I get to a terminal. Quick question: can you take second derivatives with your implementation?
@jramapuram Not yet, I wanted to get the first order working first :)
@jramapuram I went ahead and changed the gradients to be Variables too. This should make it easy to perform higher-order derivatives.
@pavanky Just tested it on my laptop and it looks pretty neat. Unlike Python, I did not see any initial delay; I guess that might be because there is no JIT involved. When will this be merged into this repo?
@itsnarsi This is still very nascent. I want to incorporate some of the stuff mentioned here to make it more efficient: http://pytorch.org/docs/master/notes/autograd.html#excluding-subgraphs
I've decreased the scope of the PR to get a minimum viable version going. The additional functions and operators can be added once this PR gets merged.
@jramapuram I think enabling support for higher-order derivatives by default will increase memory usage, so I am going to add a flag to enable it during the backward pass. By default, only the values will be stored.
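As a rough illustration of that plan (the extra parameter shown here is hypothetical and not part of the current API), usage could look something like this:

```cpp
// Hypothetical sketch of the proposed opt-in flag; the second parameter to
// backward() is an assumption, not part of the current API.
auto x  = Variable(af::randu(5), true);
auto z  = x * x;
auto dz = Variable(af::constant(1.0, 5), false);

z.backward(dz);        // default: only gradient values are stored
z.backward(dz, true);  // opt in: retain the graph so higher-order derivatives work
```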
What is done so far:

- `autograd::Variable`, `autograd::backward`
- `Variable` hides the underlying `af::array` from the user.
- When `var.backward(grad_var)` is invoked, it builds a DAG as a vector starting with the current variable and propagates gradients down the graph to all the `Variable`s in the graph using the grad function specified at each variable. A variable can be excluded from gradient calculation with `var.setCalcGrad(false)`.
- Functions take `Variable` parameters and return a `Variable`.
- The returned `Variable` is constructed using the following arguments as parameters:
  - `af::array`: The result calculated earlier
  - `vector<Variable>`: containing the inputs to the function
  - `BackwardFunction_t`: A function pointer to the backward pass, usually implemented as a lambda function.
- Example function:
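A minimal sketch of what such a function could look like, based on the constructor described above; the include paths and the `array()`/`addGrad()` accessor names are assumptions rather than the final API:

```cpp
// Sketch of a differentiable multiply built from the described
// Variable(value, inputs, backward) constructor. Include paths and the
// array()/addGrad() accessor names are assumptions.
#include <vector>
#include <arrayfire.h>
#include <af/autograd/Variable.hpp>

using af::autograd::Variable;

Variable operator*(const Variable &lhs, const Variable &rhs)
{
    // Forward pass: compute the value eagerly on the wrapped af::array
    auto result = lhs.array() * rhs.array();

    // Backward pass: given the gradient of the output, accumulate the
    // gradient of each input (d(l*r)/dl = r, d(l*r)/dr = l)
    auto backward = [](std::vector<Variable> &inputs, const Variable &grad_output) {
        inputs[0].addGrad(grad_output * inputs[1]);
        inputs[1].addGrad(grad_output * inputs[0]);
    };

    // The resulting Variable stores the value, its inputs, and the backward function
    return Variable(result, {lhs, rhs}, backward);
}
```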
- Example: A simple example showcasing how this can be done currently.
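A minimal usage sketch, assuming a `Variable(af::array, bool calc_grad)` constructor and the `backward()`/`grad()` accessors described above (the include path is also an assumption):

```cpp
// Minimal usage sketch; the Variable(af::array, bool) constructor, the grad()
// accessor, and the include path are assumptions based on the description above.
#include <arrayfire.h>
#include <af/autograd.h>

using af::autograd::Variable;

int main()
{
    auto x = Variable(af::randu(5), true);   // track gradients for x
    auto y = Variable(af::randu(5), true);   // track gradients for y

    auto z = x * y + x;                      // forward pass builds the DAG

    // Seed gradient of ones with the same shape as z
    auto dz = Variable(af::constant(1.0, 5), false);
    z.backward(dz);                          // propagate gradients down the graph

    af::print("dz/dx", x.grad().array());    // expected: y + 1
    af::print("dz/dy", y.grad().array());    // expected: x
    return 0;
}
```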
TODO for this PR:
- [x] Add train and evaluation mode for modules