Jutho / TensorOperations.jl

Julia package for tensor contractions and related operations
https://jutho.github.io/TensorOperations.jl/stable/

cuTENSOR not working with automatic differentiation #167

Closed: hxjz233 closed this issue 2 months ago

hxjz233 commented 2 months ago

I ran into difficulties running my tensor calculations on a GPU, and the problem basically boils down to backpropagating through tensor operations. Here is a simplified version of the code:

```julia
using TensorOperations
using ChainRulesCore, Zygote
using CUDA, cuTENSOR

function QuadMin(x)
    # demonstrates a tensor operation with explicit index order: res = Σ_ij x[i,j] * x[i,j]
    @cutensor res = x[i,j] * x[i,j]
    return res
end

function AD4CuArray()
    initval = ones(3, 3)    # CPU Matrix; @cutensor transfers it to the GPU
    f(x) = QuadMin(x)
    g(x) = gradient(f, x)[1]
    println(g(initval))
    return nothing
end

AD4CuArray()
```

The code runs fine if the target function uses @tensor instead. Should I modify my code or wait for a later update? Or is making cuTENSOR work with backpropagation impossible in principle?
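For reference, QuadMin contracts every index of x with itself, so res = Σ_ij x[i,j]^2 and the expected gradient is 2x; for initval = ones(3, 3) that is a 3×3 matrix of 2.0. The following CPU-only sketch uses only Base Julia and hypothetical helper names (quadmin_cpu, analytic_grad, fd_grad). It checks that value with central finite differences, giving a baseline to compare against whatever the @tensor or @cutensor path returns.

```julia
# CPU-only cross-check for the QuadMin example; Base Julia only (no Zygote, no GPU).

# QuadMin contracts every index pair: res = Σ_ij x[i,j]^2
quadmin_cpu(x) = sum(abs2, x)

# Analytic gradient of Σ_ij x[i,j]^2 with respect to x is 2x
analytic_grad(x) = 2 .* x

# Central finite differences, perturbing one entry at a time (hypothetical helper)
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for I in eachindex(x)
        xp = copy(x); xp[I] += h
        xm = copy(x); xm[I] -= h
        g[I] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

initval = ones(3, 3)
expected = analytic_grad(initval)    # 3×3 matrix of 2.0
println(expected)
println(isapprox(fd_grad(quadmin_cpu, initval), expected; atol = 1e-6))  # prints true
```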

lkdvos commented 2 months ago

Hi hxjz233!

I had a look at this, and it is indeed a mistake on my end. The rrules assume that some default argument gets filled in somewhere, which no longer holds as soon as a backend other than the default one is specified. That is exactly how @cutensor works: it implicitly inserts the cuTENSOR backend everywhere. I think I have fixed it and wrote some additional tests to prevent future failures; once the tests pass, I should be able to merge this.
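The failure mode described above can be modeled in plain Julia. In this toy sketch all names (Backend, quad, backward, rrule_buggy, rrule_fixed) are hypothetical stand-ins, not TensorOperations internals. A reverse rule whose pullback assumes the default backend silently ignores the backend chosen for the forward call, while a rule whose pullback closure captures the forward backend keeps using it.

```julia
# Toy model of the bug class in plain Julia. All names are hypothetical;
# this is not TensorOperations' actual rrule code.
struct Backend{B} end
const DEFAULT_BACKEND = Backend{:default}()

# Forward pass of res = x[i,j] * x[i,j]; the backend only labels the computation here
quad(x, ::Backend) = sum(abs2, x)

# Stand-in for a backend-specific backward kernel: returns the gradient
# together with the name of the backend that computed it
backward(x, Δ, ::Backend{B}) where {B} = (2Δ .* x, B)

# Buggy rule: the pullback falls back to the default backend
function rrule_buggy(x, backend::Backend)
    y = quad(x, backend)
    pullback(Δ) = backward(x, Δ, DEFAULT_BACKEND)
    return y, pullback
end

# Fixed rule: the pullback reuses the backend captured from the forward call
function rrule_fixed(x, backend::Backend)
    y = quad(x, backend)
    pullback(Δ) = backward(x, Δ, backend)
    return y, pullback
end

x = ones(2, 2)
gpu = Backend{:cuTENSOR}()
_, pb_buggy = rrule_buggy(x, gpu)
_, pb_fixed = rrule_fixed(x, gpu)
println(pb_buggy(1.0)[2])  # default
println(pb_fixed(1.0)[2])  # cuTENSOR
```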

On a separate note, I noticed that some of the implementations use VectorInterface, which by default falls back to a broadcasting operation; that is not necessarily what you want for CuArrays. I'll write a fix for that and update you here once it is finished.

In any case, thanks for letting me know that this is broken. I hope to have it fixed asap, as this is definitely a problem on our side.

lkdvos commented 2 months ago

https://github.com/Jutho/VectorInterface.jl/pull/14 should also get rid of the scalar-indexing warning with CuArrays. Feel free to reopen this issue if things still do not work the way you expect!

hxjz233 commented 2 months ago

Hi lkdvos, just FYI: the given code still fails, with the same error as before:

Error Message
```
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore D:\Julia\depot\packages\GPUArraysCore\uOYfN\src\GPUArraysCore.jl:103
  [3] getindex
    @ D:\Julia\depot\packages\GPUArrays\dAUOE\src\host\indexing.jl:48 [inlined]
  [4] scalar_getindex
    @ D:\Julia\depot\packages\GPUArrays\dAUOE\src\host\indexing.jl:34 [inlined]
  [5] _getindex
    @ D:\Julia\depot\packages\GPUArrays\dAUOE\src\host\indexing.jl:17 [inlined]
  [6] getindex
    @ D:\Julia\depot\packages\GPUArrays\dAUOE\src\host\indexing.jl:15 [inlined]
  [7] scale(x::CuArray{Float64, 0, CUDA.Mem.DeviceBuffer}, α::VectorInterface.Zero)
    @ VectorInterface D:\Julia\depot\packages\VectorInterface\TAlcJ\src\abstractarray.jl:39
  [8] #61
    @ D:\Julia\depot\packages\TensorOperations\dNaBM\ext\TensorOperationsChainRulesCoreExt.jl:93 [inlined]
  [9] unthunk
    @ D:\Julia\depot\packages\ChainRulesCore\zgT0R\src\tangent_types\thunks.jl:204 [inlined]
 [10] wrap_chainrules_output
    @ D:\Julia\depot\packages\Zygote\jxHJc\src\compiler\chainrules.jl:110 [inlined]
 [11] map (repeats 2 times)
    @ .\tuple.jl:276 [inlined]
 [12] wrap_chainrules_output
    @ D:\Julia\depot\packages\Zygote\jxHJc\src\compiler\chainrules.jl:111 [inlined]
 [13] ZBack
    @ D:\Julia\depot\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
 [14] Pullback
    @ D:\MagBEC\juliatest\t_adjulia\cuTensorAD.jl:6 [inlined]
 [15] (::Zygote.Pullback{Tuple{typeof(QuadMin), Matrix{Float64}}, Tuple{Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensorfree!_pullback#47"{Tuple{CuArray{Float64, 0, CUDA.Mem.DeviceBuffer}, TensorOperations.Backend{:cuTENSOR}}}}, Zygote.Pullback{Tuple{typeof(scalartype), Matrix{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Matrix{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}, Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensorscalar_pullback#49"{CuArray{Float64, 0, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(scalartype), Matrix{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Matrix{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}, Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensoralloc_contract_pullback#41"{Tuple{DataType, Tuple{Tuple{}, Tuple{}}, Matrix{Float64}, Tuple{Tuple{}, Tuple{Int64, Int64}}, Symbol, Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Symbol, Bool, TensorOperations.Backend{:cuTENSOR}}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#pullback#67"{CuArray{Float64, 0, CUDA.Mem.DeviceBuffer}, Tuple{Tuple{}, Tuple{}}, Matrix{Float64}, Tuple{Tuple{}, Tuple{Int64, Int64}}, Symbol, Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Symbol, VectorInterface.One, VectorInterface.Zero, Tuple{TensorOperations.Backend{:cuTENSOR}}, ProjectTo{Number, NamedTuple{(), Tuple{}}}, ProjectTo{Number, NamedTuple{(), Tuple{}}}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{}}}}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(TensorOperations.promote_contract), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(TensorOperations.tensorop), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}, Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(TensorOperations.tensorop), Type{Tuple{Float64, Float64}}}, Tuple{typeof(Core.Compiler.return_type)}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}})(Δ::Float64)
    @ Zygote D:\Julia\depot\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [16] Pullback
    @ D:\MagBEC\juliatest\t_adjulia\cuTensorAD.jl:12 [inlined]
 [17] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#f#3", Matrix{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(QuadMin), Matrix{Float64}}, Tuple{Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensorfree!_pullback#47"{Tuple{CuArray{Float64, 0, CUDA.Mem.DeviceBuffer}, TensorOperations.Backend{:cuTENSOR}}}}, Zygote.Pullback{Tuple{typeof(scalartype), Matrix{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Matrix{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}, Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensorscalar_pullback#49"{CuArray{Float64, 0, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(scalartype), Matrix{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Matrix{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}, Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensoralloc_contract_pullback#41"{Tuple{DataType, Tuple{Tuple{}, Tuple{}}, Matrix{Float64}, Tuple{Tuple{}, Tuple{Int64, Int64}}, Symbol, Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Symbol, Bool, TensorOperations.Backend{:cuTENSOR}}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#pullback#67"{CuArray{Float64, 0, CUDA.Mem.DeviceBuffer}, Tuple{Tuple{}, Tuple{}}, Matrix{Float64}, Tuple{Tuple{}, Tuple{Int64, Int64}}, Symbol, Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Symbol, VectorInterface.One, VectorInterface.Zero, Tuple{TensorOperations.Backend{:cuTENSOR}}, ProjectTo{Number, NamedTuple{(), Tuple{}}}, ProjectTo{Number, NamedTuple{(), Tuple{}}}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{}}}}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(TensorOperations.promote_contract), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(TensorOperations.tensorop), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}, Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(TensorOperations.tensorop), Type{Tuple{Float64, Float64}}}, Tuple{typeof(Core.Compiler.return_type)}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}}}})(Δ::Float64)
    @ Zygote D:\Julia\depot\packages\Zygote\jxHJc\src\compiler\interface.jl:91
 [18] gradient(f::Function, args::Matrix{Float64})
    @ Zygote D:\Julia\depot\packages\Zygote\jxHJc\src\compiler\interface.jl:148
 [19] g
    @ D:\MagBEC\juliatest\t_adjulia\cuTensorAD.jl:13 [inlined]
 [20] AD4CuArray()
    @ Main D:\MagBEC\juliatest\t_adjulia\cuTensorAD.jl:14
 [21] top-level scope
    @ D:\MagBEC\juliatest\t_adjulia\cuTensorAD.jl:18
```
Version Info
```
julia> Pkg.status(["TensorOperations","ChainRulesCore","Zygote","Yota","CUDA","cuTENSOR"])
Status `D:\Julia\depot\environments\v1.9\Project.toml`
⌅ [052768ef] CUDA v5.1.2
  [d360d2e6] ChainRulesCore v1.23.0
  [6aa20fa7] TensorOperations v4.1.1
  [cd998857] Yota v0.8.5
  [e88e6eb3] Zygote v0.6.69
⌃ [011b41b2] cuTENSOR v1.2.1
```

However, if I switch to Yota and use g(x) = grad(f, x), it works, so there may be something to check on the Zygote side. Either way, AD with cuTENSOR is possible after all, and that is already quite cool!
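Frame [7] of the trace above fails inside VectorInterface's scale, which is called with a zero-dimensional CuArray (the scalar result of the full contraction), when getindex trips GPUArrays' scalar-indexing check. The CPU-only sketch below is not VectorInterface's code. It only illustrates, with Base Julia, how zero-dimensional arrays behave: out-of-place broadcasting unwraps them to a plain scalar, reading the value needs x[], and zero, fill! and in-place broadcasting keep the container without reading an element. That CUDA.jl implements those last three without host-side element access is an assumption worth checking for your CUDA.jl version.

```julia
# CPU-only illustration of zero-dimensional arrays in Base Julia; no CUDA needed.
# Not VectorInterface's code, just the Base behavior a generic scale has to handle.
x = fill(3.0)                  # Array{Float64, 0}, analogous to the CuArray{Float64, 0} in frame [7]

# Out-of-place broadcasting over only 0-dim arguments returns an unwrapped scalar
y = x .* 0.0
println(typeof(y))             # Float64

# Reading the wrapped value needs indexing; on a CuArray this counts as
# scalar indexing, which GPUArrays rejects unless explicitly allowed
v = x[]

# Container-preserving alternatives that never read an element
z1 = zero(x)                   # new 0-dim array holding 0.0
z2 = fill!(similar(x), 0)      # same result via fill!
x .= x .* 2                    # in-place broadcast keeps x a 0-dim array
println(typeof(z1), " ", x[])  # Array{Float64, 0} 6.0
```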

lkdvos commented 2 months ago

The changes in VectorInterface have not been tagged yet, but this should be resolved once https://github.com/JuliaRegistries/General/pull/105225 is merged. Would you mind trying again with VectorInterface v0.4.5? I am hoping that fixes it.
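A minimal sketch for picking up the new release once the registry PR is merged. It assumes VectorInterface is only an indirect dependency (pulled in by TensorOperations) of the active environment, so it appears in the Manifest rather than in the Project:

```julia
using Pkg

# Pull newly registered releases, including indirect dependencies such as
# VectorInterface, into the active environment
Pkg.update()

# Manifest mode also lists indirect dependencies; the reported
# VectorInterface version should be v0.4.5 or later
Pkg.status("VectorInterface"; mode = Pkg.PKGMODE_MANIFEST)
```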

hxjz233 commented 2 months ago

Yes, it solves the issue! Thanks for the effort! :)