sangyx closed this 4 years ago.
Ah, yes. I remember running into this when I was initially working on udiff. Basically, with your approach, it would work as follows:

1. The DiffArray.to method should only compute the diff with respect to what's passed in, nothing else.
2. DiffArray.backward should compute the calculation graph, but only for "the current level". The next level should remain "unevaluated", i.e. in the form of a graph, until you actually call DiffArray.diff[some_name].backward.

I have fixed the problem. In each diff function, all the calculations that need to register gradients have to be placed inside lambda g (or vjp), so as to avoid infinite loops.
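Below is a minimal sketch of both points. It uses a toy scalar graph rather than udiff itself, and Node, grad, sin, cos, sin_eager and cos_eager are hypothetical names, not udiff's API.

Each vjp is a lambda g that runs only during the backward pass. It is built from the same differentiable operations, so grad() processes one level and returns the derivative as an unevaluated graph that can be differentiated again. The *_eager functions show one way the infinite loop can arise when the derivative is built at forward time instead of inside the lambda.

```python
import math


class Node:
    """Scalar value in a computation graph (hypothetical stand-in for DiffArray)."""

    def __init__(self, value, parents=()):
        self.value = value
        # Pairs of (input_node, vjp); vjp maps the upstream gradient Node
        # to this input's gradient contribution, also as a Node.
        self.parents = parents

    def __add__(self, other):
        other = _lift(other)
        return Node(self.value + other.value,
                    ((self, lambda g: g), (other, lambda g: g)))

    def __mul__(self, other):
        other = _lift(other)
        # The lambdas capture the input Nodes, so the products they compute
        # during backward are recorded in the graph as well.
        return Node(self.value * other.value,
                    ((self, lambda g: g * other), (other, lambda g: g * self)))

    __radd__ = __add__
    __rmul__ = __mul__

    def __neg__(self):
        return self * -1.0


def _lift(x):
    return x if isinstance(x, Node) else Node(float(x))


def sin(x):
    # cos(x) is only built when backward calls the lambda.
    return Node(math.sin(x.value), ((x, lambda g: g * cos(x)),))


def cos(x):
    return Node(math.cos(x.value), ((x, lambda g: g * -sin(x)),))


def grad(output, wrt):
    """Return d(output)/d(wrt) as a Node: one level only, still a graph."""
    order, seen = [], set()

    def visit(node):  # depth-first post-order: inputs before outputs
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)

    visit(output)
    grads = {id(output): Node(1.0)}
    for node in reversed(order):
        g = grads.get(id(node))
        if g is None:
            continue
        for parent, vjp in node.parents:
            contribution = vjp(g)
            key = id(parent)
            grads[key] = grads[key] + contribution if key in grads else contribution
    return grads.get(id(wrt), Node(0.0))


# Failure mode: building the derivative at forward time never terminates,
# because sin_eager builds cos_eager, which builds sin_eager, and so on.
def sin_eager(x):
    d = cos_eager(x)
    return Node(math.sin(x.value), ((x, lambda g: g * d),))


def cos_eager(x):
    d = -sin_eager(x)
    return Node(math.cos(x.value), ((x, lambda g: g * d),))


x = Node(0.5)
y = sin(x) * x        # y = x*sin(x)
dy = grad(y, x)       # graph for sin(x) + x*cos(x)
d2y = grad(dy, x)     # graph for 2*cos(x) - x*sin(x)
print(dy.value, d2y.value)
```

The second call, grad(dy, x), walks the graph that the first call built. That graph is the "next level", and it stays unevaluated until it is explicitly requested.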
The error of <uarray multimethod '__instancecheck__'> in CI seems to be related to uarray.
I'll fix those up. Don't worry about them for now.
@peterbell10 I can reproduce the segfault in CI. Could you investigate? I was also getting a bunch of "returned NULL without setting an error" errors in determine_backend.
@sangyx Passing the torch back to you. 😄 The segfaults should be resolved now.
Edit: To resolve them locally, get the master branch of uarray and unumpy, and run pip install -e . in the directory of each.
Thanks!
@hameerabbasi Hi, I want to add <uarray multimethod '__instancecheck__'> to raw_functions, but I don't know how to refer to it. How can I reach it from code, the way np.ndim is reached?
You can probably do something like unumpy.ClassOverrideMeta.__instancecheck__.
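Here is a small illustration of that lookup. It assumes the multimethod behaves like an ordinary function under attribute lookup. ClassOverrideMeta below is a local stand-in, and ndarray and ndim are hypothetical. None of these are the real unumpy objects.

Looking the hook up on the metaclass returns the stored object itself. Looking it up on a class that uses the metaclass returns a bound method, which would not match an entry in raw_functions.

```python
class ClassOverrideMeta(type):
    """Local stand-in; in unumpy this hook is a uarray multimethod."""

    def __instancecheck__(cls, instance):
        return super().__instancecheck__(instance)


class ndarray(metaclass=ClassOverrideMeta):
    """Hypothetical class whose isinstance() checks go through the hook."""


def ndim(x):
    """Hypothetical stand-in for np.ndim."""
    return getattr(x, "ndim", 0)


raw_functions = {ndim}
# Look the hook up on the metaclass, just as np.ndim is looked up on np.
raw_functions.add(ClassOverrideMeta.__instancecheck__)

print(ClassOverrideMeta.__instancecheck__ in raw_functions)  # True
print(ndarray.__instancecheck__ in raw_functions)            # False: bound method
```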
@hameerabbasi Could you help me check whether the NoGradBackend is correct?
Thanks @sangyx!
I have added the diff function of np.stack.

When we extend the code to higher-order derivatives, the situation becomes a bit more complicated. There are some places to improve:

1. np.sum is used when registering the gradient of np.sum.
2. I plan to add depth attributes to DiffArray or to() to control the order of gradients.
3. ufunc.__call__

Do you have any ideas about the above questions?
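For reference on np.stack and on the first item, here is a sketch of the relevant gradient rules in plain NumPy. It uses a hypothetical vjp(g, ...) convention rather than udiff's registration API.

It shows why the np.sum issue appears at higher order. The rule for np.sum is a broadcast, and the rule for a broadcast is an np.sum. Once these calls are themselves differentiable, each rule depends on the other.

```python
import numpy as np


def stack_vjp(g, axis=0):
    """Gradient of np.stack(arrays, axis): one slice of g per input array."""
    return [np.take(g, i, axis=axis) for i in range(g.shape[axis])]


def sum_vjp(g, x_shape, axis=None):
    """Gradient of np.sum(x, axis) with keepdims=False: broadcast g to x's shape."""
    if axis is not None:
        g = np.expand_dims(g, axis)  # re-insert the reduced axes
    return np.broadcast_to(g, x_shape)


def broadcast_to_vjp(g, x_shape):
    """Gradient of np.broadcast_to(x, shape): sum g over the broadcast axes."""
    lead = g.ndim - len(x_shape)
    g = np.sum(g, axis=tuple(range(lead)))  # axes added in front of x
    ones = tuple(i for i, n in enumerate(x_shape) if n == 1)
    return np.sum(g, axis=ones, keepdims=True)  # axes where x had size 1


# Upstream gradient for np.stack([a, b], axis=1) with a, b of shape (2,).
ga, gb = stack_vjp(np.arange(4.0).reshape(2, 2), axis=1)
s = sum_vjp(np.array([1.0, 2.0, 3.0]), (2, 3), axis=0)
b = broadcast_to_vjp(np.ones((4, 2, 3)), (2, 1))
print(ga, gb, s.shape, b.shape)
```

If sum_vjp and broadcast_to_vjp were written with differentiable DiffArray operations and evaluated eagerly, the second-order gradient of np.sum would route back through np.sum. That is the same situation the lambda g deferral avoids.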