FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Issue taking gradients of Chains on GPU #1853

Closed: jarbus closed this issue 2 years ago

jarbus commented 2 years ago

I'm sure this is just a misunderstanding on my part, but I get an error when taking the gradient of a model with multiple layers on the GPU:

using Flux
d1gpu = Chain(Dense(1,1, relu)) |> gpu
d2gpu = Chain(Dense(1,1, relu),Dense(1,1)) |> gpu
d2cpu = Chain(Dense(1,1, relu),Dense(1,1))

gradient(()->d1gpu([5]|>gpu)[1], params(d1))     # no problem
gradient(()->d2cpu([5])[1], params(d2cpu))        # no problem
gradient(()->d2gpu([5]|>gpu)[1], params(d2gpu))  # yields the error below
ERROR: GPU compilation of kernel broadcast_kernel(CUDA.CuKernelContext, CUDA.CuDeviceVector{Float32, 1}, Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(>), Tuple{Base.Broadcast.Extruded{CUDA.CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}, Int64}}}}, Int64) failed
KernelError: passing and using non-bitstype argument

The only difference between d1gpu, which works, and d2gpu, which doesn't, is that d2gpu has two layers. What's the intended way to take gradients when the model is on the GPU?

ToucheSir commented 2 years ago

What is d1? I don't see it defined anywhere, do you mean d1gpu?

When I run the third line, I get a scalar indexing warning. This probably comes from d2gpu(...)[1]. Switching to a GPU-compatible version:

gradient(()->d2gpu([5]|>gpu) |> sum, params(d2gpu))

Yields no errors.

Side note: you want to avoid calling cpu and gpu in a gradient context whenever possible. It creates extra work for the automatic differentiation (AD) for no benefit. If it can happen outside of gradient, it probably should.
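A minimal sketch combining both suggestions, assuming a recent Flux (0.13 or later) and its explicit-style `Flux.gradient(f, model)` API rather than the implicit `params` style used above. The model and input are moved with `gpu` once, outside the gradient, and the output is reduced with `sum`, which for this one-element output gives the same value as indexing with `[1]`. If no GPU backend is loaded, `gpu` should leave everything on the CPU, so the sketch also runs there.

```julia
using Flux

# Move the model to the GPU once, outside of any gradient call.
# (Without a loaded GPU backend, `gpu` returns its argument unchanged.)
model = Chain(Dense(1 => 1, relu), Dense(1 => 1)) |> gpu

# Move the input once as well; Float32 matches the layer weights.
x = Float32[5] |> gpu

# Reduce with `sum` instead of indexing with `[1]`, so no scalar
# indexing of a GPU array happens inside the differentiated code.
grads = Flux.gradient(m -> sum(m(x)), model)

# The gradient mirrors the model's structure: grads[1].layers[2].bias
# is the gradient with respect to the second layer's bias.
```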

jarbus commented 2 years ago

Thank you for the response, and the side note, @ToucheSir. Yeah, I meant d1gpu originally; I changed the variable name d1 to d1gpu and missed an instance.

Using your GPU-compatible version works, thanks! I didn't realize the issue was specifically with scalar indexing, since it works fine on the CPU. I also didn't realize the sum function could be used to get around the issue; it's a good workaround. However, if I really do need the gradient of a single element of a vector, how would I go about it without indexing? In my case, I'm doing reinforcement learning and need the gradient w.r.t. a specific action probability.

I suppose I could multiply the output by a one-hot vector and then get the sum, but that seems more like a hack.
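The one-hot idea can be sketched as follows, assuming Flux 0.13 or later; the policy network, the state, and the chosen action are hypothetical placeholders. Because `sum(p .* mask)` equals `p[action]`, this differentiates the single action probability using only broadcasting and a reduction. The mask here is a dense `Float32` vector built from `Flux.onehot`; with the model on the GPU, the mask would presumably also need to be moved with `gpu`, outside the gradient.

```julia
using Flux

# Hypothetical policy network: 4 state features in, 3 action probabilities out.
policy = Chain(Dense(4 => 8, relu), Dense(8 => 3), softmax)

state  = rand(Float32, 4)   # hypothetical observation
action = 2                  # hypothetical index of the action taken

# Dense Float32 one-hot mask that is 1 at `action` and 0 elsewhere.
mask = Float32.(Flux.onehot(action, 1:3))

# sum(p .* mask) == p[action], computed without indexing into p.
grads = Flux.gradient(m -> sum(m(state) .* mask), policy)
```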

ToucheSir commented 2 years ago

GPU arrays and host-side indexing don't mix very well, so the hack may well be the fastest way to do this. I don't recall whether multiplying by a one-hot array is optimized for the GPU, but it might be. Alternatively, you could use NNlib.gather to grab only the indices/columns/rows you need, since it's GPU-compatible and has a well-defined gradient.
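A minimal sketch of the `NNlib.gather` route for a batch of transitions, assuming Flux 0.13 or later; the 3-action policy, the states, and the actions are made up for illustration. Each index is a `(row, column)` tuple, so `gather` returns one entry per transition: the probability of the action taken in that column. The log-probability objective is only an example, and `NNlib` is reached through Flux here because Flux loads it as a dependency.

```julia
using Flux
using Flux: NNlib   # the NNlib module that Flux itself depends on

# Hypothetical batch: 3 possible actions, 5 transitions (one per column).
policy  = Chain(Dense(4 => 8, relu), Dense(8 => 3), softmax)
states  = rand(Float32, 4, 5)
actions = [1, 3, 2, 2, 1]

# One (action, column) tuple per transition.
idx = tuple.(actions, 1:length(actions))

# Hypothetical helper: k-th entry is probs[actions[k], k].
selected(m) = NNlib.gather(m(states), idx)

# Example objective: sum of log-probabilities of the taken actions.
grads = Flux.gradient(m -> sum(log.(selected(m))), policy)
```

For a single action on a single state, the same call with `idx = [action]` on the output vector picks out one probability.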

jarbus commented 2 years ago

Perfect, that's exactly what I was looking for but didn't realize it existed, thank you so much!