hxjz233 closed this issue 2 months ago.
Hi hxjz233!
I had a look at this, and it looks like this is indeed a mistake on my end: I was assuming that some default argument would be filled in somewhere in the rrules, but this does not work as soon as a backend other than the default one is specified (which is how `@cutensor` works: it implicitly inserts the cuTENSOR backend everywhere).
I think I fixed it, and I wrote some additional tests to prevent future failures. Once the tests pass, I should be able to merge this.
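To make the failure mode concrete, here is a minimal Julia sketch of the general pattern, written with ChainRulesCore. It is not TensorOperations.jl's actual code: `mymul`, `DefaultBackend`, `GPUBackend` and `LAST_BACKEND` are hypothetical names introduced only for illustration. The point it shows is that a pullback has to forward the backend it was called with; if it relies on the keyword default instead, the reverse pass silently switches back to the default backend.

```julia
using ChainRulesCore

# Hypothetical stand-ins for the default backend and a cuTENSOR-like backend;
# neither type is part of TensorOperations.jl.
struct DefaultBackend end
struct GPUBackend end

# Records which backend handled the most recent call, so the effect is observable.
const LAST_BACKEND = Ref{Any}(nothing)

# Hypothetical matrix product that takes a `backend` keyword.
function mymul(A, B; backend=DefaultBackend())
    LAST_BACKEND[] = backend
    return A * B  # a real GPU backend would launch its own kernel here
end

function ChainRulesCore.rrule(::typeof(mymul), A, B; backend=DefaultBackend())
    C = mymul(A, B; backend=backend)
    function mymul_pullback(ΔC)
        Δ = unthunk(ΔC)
        # Forward the caller's backend explicitly. Omitting `backend=backend`
        # here would fall back to DefaultBackend() in the reverse pass.
        ΔA = mymul(Δ, B'; backend=backend)
        ΔB = mymul(A', Δ; backend=backend)
        return NoTangent(), ΔA, ΔB
    end
    return C, mymul_pullback
end
```

Calling `rrule(mymul, A, B; backend=GPUBackend())` and then the returned pullback leaves `LAST_BACKEND[]` as a `GPUBackend()`. Dropping `backend=backend` from the two calls inside the pullback would leave it as `DefaultBackend()` instead, which is the kind of mismatch described above for `@cutensor`.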
On a separate note, I noticed that this uses VectorInterface for some of the implementations, which by default falls back to a broadcasting operation; that is not necessarily what you want for CuArrays. I'll write a fix for that and update you here once it is done.
In any case, thanks for letting me know that this is broken. I hope to have it fixed asap, as this is definitely something that is wrong on our side.
https://github.com/Jutho/VectorInterface.jl/pull/14 should also get rid of the warning message about scalar indexing with CuArrays. Feel free to reopen the issue if things still do not work the way you expect!
Hi lkdvos, just FYI, the given code still fails, and in fact gives the same output as before.
But if you switch to Yota and use `g(x) = grad(f, x)`, it works, so there might be something to check on the Zygote side. Nevertheless, there is a solution for AD + cuTENSOR after all, and that is already quite cool!
The changes in VectorInterface have not been tagged yet, but this should be resolved once this is merged: https://github.com/JuliaRegistries/General/pull/105225 Would you mind trying again with version v0.4.5 of VectorInterface? I am hoping that fixes it.
Yes, it solves the issue! Thanks for the effort! :)
I ran into difficulties implementing my tensor calculations on a GPU, and the problem basically comes down to backpropagating through tensor operations. Here is a simplified version of the code.
The given code runs fine if the target function uses `@tensor`. Should I modify my code or wait for later updates? Or is making cuTENSOR work with backpropagation impossible in principle?