JuliaDiff / TaylorDiff.jl

Taylor-mode automatic differentiation for higher-order derivatives
https://juliadiff.org/TaylorDiff.jl/
MIT License
73 stars, 8 forks

TaylorDiff derivative() does not support GPU types #42

Closed: jacob-m-wilson-42 closed this issue 1 year ago

jacob-m-wilson-42 commented 1 year ago

Hello everyone! I hope this is not a simple mistake on my part, but it appears that TaylorDiff's derivative() does not accept GPU arrays. I haven't worked with CuArrays for a year or two, but it looks like the CuArray type now carries some additional information that keeps it from being treated like a vanilla array. Please see the MWE below.

using Flux
using TaylorDiff

# works
testinput = rand(3)
derivative_direction = [1e0,0e0,0e0]
model = Dense(3=>1,sin)
model(testinput) # works
derivative(modelinput -> model(modelinput)[1], testinput, derivative_direction, 2) # works

# doesn't work
testinput_gpu = gpu(rand(3))
derivative_direction_gpu = gpu([1e0,0e0,0e0])
model_gpu = gpu(model) # same layer as the CPU model, with its parameters moved to the GPU
model_gpu(testinput_gpu) # works, output is still on the GPU
derivative(modelinput -> model_gpu(modelinput)[1], testinput_gpu, derivative_direction_gpu, 2) # breaks here

The resulting error is:

ERROR: MethodError: no method matching derivative(::var"#27#28", ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Int64)
Closest candidates are:
  derivative(::Any, ::Vector{T}, ::Vector{T}, ::Int64) where T<:Number at C:\Users\bachs\.julia\packages\TaylorDiff\zNnz2\src\derivative.jl:21
  derivative(::Any, ::T, ::Int64) where T<:Number at C:\Users\bachs\.julia\packages\TaylorDiff\zNnz2\src\derivative.jl:17
  derivative(::Any, ::T, ::Val{N}) where {T<:Number, N} at C:\Users\bachs\.julia\packages\TaylorDiff\zNnz2\src\derivative.jl:26
  ...

I think the CuArray type is causing the issue, since it has additional type parameters compared to a plain Vector:

typeof(testinput_gpu) # CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}
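The candidate list in the error points at the likely cause: the vector method of derivative is declared for Vector{T}, i.e. Array{T,1}, and a CuArray is an AbstractVector but not an Array, so Julia's dispatch finds no matching method. The sketch below reproduces that behavior on the CPU without CUDA. The functions narrow and wide are hypothetical stand-ins, and a view is used as a non-Array vector in place of a CuArray; this illustrates the dispatch rule only and is not TaylorDiff's code or the change made in #44.

```julia
# Hypothetical stand-ins, NOT TaylorDiff's code.
# `narrow` mirrors the candidate shown in the error: it only accepts `Vector{T}` (= `Array{T,1}`).
narrow(f, x::Vector{T}, l::Vector{T}, order::Int64) where {T<:Number} = :matched

# `wide` accepts any `AbstractVector{T}`, which includes `Vector`, views, ranges, and GPU arrays.
wide(f, x::AbstractVector{T}, l::AbstractVector{T}, order::Int64) where {T<:Number} = :matched

x = [1.0, 2.0, 3.0]                  # a plain Vector{Float64}
v = view([1.0, 2.0, 3.0, 4.0], 1:3)  # a SubArray: an AbstractVector that is not a Vector

narrow(sum, x, x, 2)  # dispatches fine for a plain Vector
wide(sum, v, v, 2)    # dispatches fine for any AbstractVector

# Calling `narrow` with a non-`Vector` array raises a MethodError, like the CuArray call above.
caught = try
    narrow(sum, v, v, 2)
    false
catch err
    err isa MethodError
end

println(caught)                                   # true
println(v isa AbstractVector, " ", v isa Vector)  # true false
```

Widening a signature from Vector{T} to AbstractVector{T} is a common way to let Julia methods accept GPU arrays, although the method body must then also avoid operations that assume CPU memory (such as scalar indexing on a CuArray).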

If anyone could help, I would greatly appreciate it! I'm quite busy, but I'm willing to contribute to a fix.

Thanks in advance!

tansongchen commented 1 year ago

I think this has been fixed in #44. Please reopen this issue if any problem still exists.

jacob-m-wilson-42 commented 3 weeks ago

Yes, this is resolved. I came across this old post again; sorry for not responding earlier.