For higher-order differentials, we may need to traverse the calculation graph multiple times. At the same time, the calculation for each differential must itself be added to the calculation graph. This is more complicated; I need some time to consider this issue.
I think that, written correctly, the only thing to be done is to make `y.diff[x]` itself a `DiffArray` with its own calculation graph. Then we can diff it like any other array.
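For concreteness, here is a minimal, self-contained sketch of that idea: the gradient is built out of differentiable operations, so the result of one diff can itself be diffed. This is only an illustration under assumed names; it is not udiff's actual implementation.

```python
class DiffArray:
    """Toy reverse-mode node; not udiff's real class, just the idea."""

    def __init__(self, value, parents=()):
        self.value = value
        # parents: (node, local_grad_fn) pairs. Each local_grad_fn maps the
        # upstream gradient (a DiffArray) to this parent's contribution,
        # using DiffArray ops only -- so the gradient has its own graph.
        self.parents = parents

    def __add__(self, other):
        return DiffArray(self.value + other.value,
                         parents=((self, lambda g: g), (other, lambda g: g)))

    def __mul__(self, other):
        return DiffArray(self.value * other.value,
                         parents=((self, lambda g: g * other),
                                  (other, lambda g: g * self)))

    def diff(self, x):
        """d(self)/d(x), returned as a DiffArray that can be diffed again."""
        # Topological order, so each node's gradient is complete before it
        # is propagated further back.
        order, seen = [], set()

        def visit(node):
            if id(node) in seen:
                return
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

        visit(self)
        grads = {id(self): DiffArray(1.0)}
        for node in reversed(order):
            if id(node) not in grads:
                continue  # node does not influence self
            for parent, local in node.parents:
                contrib = local(grads[id(node)])
                if id(parent) in grads:
                    grads[id(parent)] = grads[id(parent)] + contrib
                else:
                    grads[id(parent)] = contrib
        return grads.get(id(x), DiffArray(0.0))


x = DiffArray(3.0)
y = x * x * x        # y = x**3
dy = y.diff(x)       # 3*x**2 -> 27.0, itself a DiffArray with a graph
d2y = dy.diff(x)     # 6*x    -> 18.0, the second derivative
print(dy.value, d2y.value)
```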
OK, I will try.
Maybe there's something to learn from the work on jets in `jax`?
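For reference, `jax.experimental.jet` propagates a truncated Taylor series through a function in a single pass, which is one way to get higher-order derivatives without re-traversing a graph once per order. A small usage sketch; the coefficient convention in the comments is my reading of the docs:

```python
import jax.numpy as jnp
from jax.experimental.jet import jet

f = jnp.sin
x0 = 0.5
# Perturb along x(t) = x0 + t: first-order coefficient 1, higher orders 0.
primal_out, coeffs = jet(f, (x0,), ((1.0, 0.0, 0.0),))

# primal_out is f(x0); coeffs are the Taylor coefficients of f(x0 + t)
# at t = 0, so (as I understand the convention) the k-th derivative is
# coeffs[k - 1] * k!.
print(primal_out)      # sin(0.5)
print(coeffs[0])       # cos(0.5)   -- first derivative
print(coeffs[1] * 2)   # -sin(0.5)  -- second derivative
```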
The latest PR has not passed the tests, while the code passes them on my machine. I think it may be a problem with the new version of `unumpy`.
Besides, do we need to change the API to not store the diff? If we don't store the diff, then some additional data needs to be stored when constructing the calculation graph, such as which nodes each node points to (otherwise we don't know whether the diff has been calculated), so not much space is saved.
> Besides, do we need to change the API to not store the diff?

It can store the diff, but only in a cache that takes the base into account.
> I think it may be a problem with the new version of `unumpy`.

Yes, it is. Just test locally; I'll try to fix up the errors.
Also, it seems to me that you're still storing the diff on `_diff`. This should be a dict with the key being the base and the value being the diff, and it should act as a cache.
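In rough sketch form (names other than `_diff` are illustrative, and `_compute_diff` is a hypothetical stand-in for the real traversal):

```python
class DiffArray:
    def __init__(self, value):
        self.value = value
        self._diff = {}  # base -> derivative w.r.t. that base (the cache)

    def to(self, base):
        # Compute at most once per base; later calls hit the cache, and
        # derivatives w.r.t. different bases can coexist.
        if base not in self._diff:
            self._diff[base] = self._compute_diff(base)
        return self._diff[base]

    def _compute_diff(self, base):
        raise NotImplementedError  # placeholder for the real graph traversal
```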
What does "it should act as a cache" mean? We can use `y.to(x)` to obtain the derivative or a higher-order derivative of `y` with respect to `x`. I think it is no different from `y.diff[x]`.
There are two motivations for this:

1. `y.to(x)` is related not just to `y` but also to `x`. Storing it directly on `y` is incorrect; it's similar to having global state (storing things globally that should be on objects), which is generally considered bad practice, and it doesn't work nicely in threaded applications. In this case, it would break down if two threads calculated `y.to(z)` and `y.to(x)` together: the last one stored would be preserved, and the other would be lost.
2. If you ask for `y.to(x)` and `y.to(z)` for another derivative, and both were computed previously, you have to re-compute one or the other. With the suggested approach, you can check if they're already present, and if so, just use them, which is good for performance.

We store the diff of `x` in `x`, not in `y`. And we have two flags, `_visit` and `_visit_jacobian`, to record the starting node (`y`) and ensure that the calculation graph is only traversed once. So `y.to(x)` and `y.to(z)` will get the correct results and will not be overwritten. And every time we call `to`, it checks whether the calculation graph has already been traversed (`if x._visit != self.name`), so the diffs of `x` and `z` will not be re-computed.
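In sketch form, the check looks roughly like this (a simplified paraphrase, not the actual code; `_backward` stands in for the real traversal):

```python
class DiffArray:
    _counter = 0

    def __init__(self, value):
        DiffArray._counter += 1
        self.name = f"v{DiffArray._counter}"  # unique node name
        self.value = value
        self.diff = None     # derivative stored on the input node
        self._visit = None   # name of the output the graph was last diffed for

    def to(self, x):
        # Re-traverse only if the graph hasn't been traversed for this
        # output yet; otherwise reuse the diff already stored on x.
        if x._visit != self.name:
            self._backward(x)   # hypothetical backward traversal
            x._visit = self.name
        return x.diff

    def _backward(self, x):
        raise NotImplementedError  # placeholder
```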
However, I think there is a potential problem if we call `y.to(x)` and `z.to(x)` (where `z = fun(y)`) in two threads at the same time: the diff of `x` would be overwritten. I will change the API to fix this problem.
Thank you for your patience.
Because we traverse the calculation graph from back to front to calculate the derivative, each node can get the information of the nodes behind it, but getting the information of the earlier nodes is not very convenient. For example, we can get the info of `v6` when calculating the diff of `v4`.

So I want to cache the value of `dy/dx` in `x.diff[y]`, on the object `x`. At the same time, we use `y.to(x)` (which returns `x.diff[y]`) to get the derivative and the higher-order derivatives. This solves the above-mentioned problems. Is this acceptable?
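In sketch form (illustrative only; `_backward` again stands in for the real traversal), the derivative is cached on `x` keyed by the output `y`, so concurrent `y.to(x)` and `z.to(x)` write to different keys instead of overwriting each other:

```python
class DiffArray:
    def __init__(self, value):
        self.value = value
        self.diff = {}  # output node y -> d(y)/d(self), cached on the input

    def to(self, x):
        # One entry per (output, input) pair: y.to(x) fills x.diff[y] on the
        # first call and only reads it afterwards, so z.to(x) cannot clobber it.
        if self not in x.diff:
            x.diff[self] = self._backward(x)  # hypothetical traversal
        return x.diff[self]

    def _backward(self, x):
        raise NotImplementedError  # placeholder for the real traversal
```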
It’s perfect! Thanks!
I'm working to refactor `unumpy` and the backend here so that this works. This PR is not blocked on your end anymore.
I fixed `unumpy`... It was a mind-bending issue involving metaclasses. I'll merge this PR once CI is green.
This is merged. Thanks for the patience, @sangyx, and thanks for the contribution!
The current code uses a derivation method similar to PyTorch's. This method of storing the gradient on each variable makes it convenient to quickly optimize the parameters during gradient descent. If that is not needed, we can modify the API to only return the derivative of one variable at a time.
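For comparison, this is roughly the PyTorch behavior being referred to: each leaf variable accumulates its own gradient, and higher-order derivatives require keeping the backward computation itself in the graph (a standard `torch.autograd` example, not udiff code):

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 3

# First derivative; create_graph=True adds the backward computation to the
# graph so that it can be differentiated again.
(dy,) = torch.autograd.grad(y, x, create_graph=True)
(d2y,) = torch.autograd.grad(dy, x)
print(dy.item(), d2y.item())  # 27.0 18.0

# The gradient stored on each variable, as mentioned above:
z = x ** 2
z.backward()
print(x.grad.item())  # 6.0
```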