Closed LeonardJohard closed 3 years ago
Could you check against the latest version of Flux? Also, it seems like you're trying to move a scalar to VRAM, which would be a no-op, since scalars are equivalent on CPU and GPU.
Tested using Flux v0.12.7 as well as latest Zygote
It is not a scalar but a 1-d array, for simplicity, but the result is the same.
The gradient is a struct with a 1-d Float32 array.
I think Dhairya means the output of sum(x.x) is a scalar. Can you try with the following modification?
structfunc(x) = sum(x.x)
Yes sorry, that was some debug code I left. Same results with just sum.
The gradient is a struct with a 1-d Float32 array.
Can you post exactly what you see, i.e. the outputs, with their types? And check that it works on a fresh session? Zygote doesn't export a gpu function. Flux does, but you didn't load it here, and this doesn't look like it should rely on details of that.
I took the liberty of editing the top post to use cu instead of gpu:
julia> structfunc'(datatest)
(x = Float32[1.0],)
julia> gradient(structfunc, datatest)
((x = Float32[1.0],),)
But a NamedTuple containing an Array and one containing a CuArray print the same. For me the code above correctly gives a CuArray, but if it doesn't for some, that is obviously a problem.
julia> (; x = cu([1f0]))
(x = Float32[1.0],)
julia> structfunc'(datatest)
(x = Float32[1.0],)
julia> ans[1]
1-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
1.0
(@v1.7) pkg> st CUDA Zygote
Status `~/.julia/environments/v1.7/Project.toml`
[052768ef] CUDA v3.4.2
[e88e6eb3] Zygote v0.6.25
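Since the compact display prints an Array and a CuArray identically inside a NamedTuple, a check on the type itself is more reliable than reading the printed output. Below is a minimal sketch of such a check. The struct name Wrapper is a hypothetical stand-in for the struct in the top post, which is not shown here, and the script assumes a working GPU with CUDA.jl and Zygote installed.

```julia
using CUDA, Zygote

# Hypothetical stand-in for the struct from the top post (not shown in this thread).
struct Wrapper{T}
    x::T
end

structfunc(w) = sum(w.x)

# `cu` comes from CUDA.jl (Flux's `gpu` is not needed); requires a working GPU.
datatest = Wrapper(cu([1f0]))

# `gradient` returns a tuple with one entry per argument;
# for a struct argument, that entry is a NamedTuple keyed by field name.
grads = gradient(structfunc, datatest)
g = grads[1]

# The compact display hides the container type, so inspect it explicitly.
println(typeof(g.x))
println(g.x isa CuArray)
```

On the package versions reported in this thread, the second line should print true. If it prints false, that would be the behaviour worth reporting, together with the output of typeof.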
Ah right, I always forget about the compact display. Anyhow, I just tested with Zygote 0.6.2 and CUDA 2.6.3 (both >= 6 months old) and get the same results (.x is a CuArray when using a custom struct), so it's fair to say things seem to be working as expected.
Are we good here?
Yes seems we are good, thanks!
As you can see, the adjoints of CuArrays are CPU arrays when placed inside structs in the current version. This is probably not the desired behaviour? In earlier versions (~6 months ago) this was not the case.