Closed: DrChainsaw closed this issue 4 years ago.
I think it's because Zygote uses Float64 by default for all gradient calculations. Change the gradient pass to

julia> Flux.Zygote.gradient(loss, ones(Float64,2,2,1,1), ones(Float64,2,2,1,1))
([0.02663578873715052 0.02663578873715052; 0.02663578873715052 0.02663578873715052], [-0.025350493677740227 -0.025350493677740227; -0.025350493677740227 -0.025350493677740227])
I had a similar issue with leakyrelu (https://github.com/FluxML/Flux.jl/issues/963). You can work with Float32 if you use a modified selu function:
function myselu(x::Real)
    λ = oftype(x/1, 1.0507009873554804934193349852946)
    α = oftype(x/1, 1.6732632423543772848170429916717)
    # one(x) keeps the literal in the element type of x, so Float32 inputs stay Float32
    λ * ifelse(x > 0, x/one(x), α * (exp(x) - one(x)))
end
loss2(x,y) = Flux.mse(Chain(MaxPool((1,1)), z -> myselu.(z))(x), y)
julia> Flux.Zygote.gradient(loss2, ones(Float32,2,2,1,1), ones(Float32,2,2,1,1))
(Float32[0.026635807 0.026635807; 0.026635807 0.026635807], Float32[-0.025350511 -0.025350511; -0.025350511 -0.025350511])
It replaces the 1 in the original selu (https://github.com/FluxML/NNlib.jl/blob/30e61ef6233a32cea5005b8738fcd110ab7a363d/src/activation.jl#L99) with one(x).
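The effect is easy to check in isolation. Assuming Flux re-exports NNlib's selu as usual, the gradient element types should come out like this on the NNlib revision linked above:

julia> eltype(Flux.Zygote.gradient(x -> sum(selu.(x)), ones(Float32, 4))[1])
Float64   # promoted by the literal 1 in selu

julia> eltype(Flux.Zygote.gradient(x -> sum(myselu.(x)), ones(Float32, 4))[1])
Float32   # one(x) keeps the input's element type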
I have created a PR for NNlib to keep the element type in gradient calculations (https://github.com/FluxML/NNlib.jl/pull/149). I hope it will also fix this issue.
If https://github.com/FluxML/NNlib.jl/pull/149 fixed the issue, this can be closed.
Sorry for the fuzzy issue, but I'm a bit uncertain what the desired behaviour is in this case.
The issue I'm seeing is this:
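A minimal reproduction, assuming the Flux/NNlib versions discussed in this thread, is something along these lines:

using Flux

loss(x,y) = Flux.mse(Chain(MaxPool((1,1)), z -> selu.(z))(x), y)

julia> Flux.Zygote.gradient(loss, ones(Float32,2,2,1,1), ones(Float32,2,2,1,1))
# throws a MethodError from ∇maxpool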
The error is that the first argument to ∇maxpool is an Array{Float64,4} while the second and third are Array{Float32,4}, but the signature requires that they all have the same type.
This seems to be caused by the pullback for the activation function (selu in this case) changing the type from Float32 to Float64:
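For example, checking the pullback of the broadcasted activation on its own shows it; the forward output stays Float32 while the cotangent coming back is Float64 (the type change described above, with the versions involved here):

julia> y, back = Flux.Zygote.pullback(x -> selu.(x), ones(Float32, 2, 2, 1, 1));

julia> eltype(y)
Float32

julia> eltype(back(ones(Float32, 2, 2, 1, 1))[1])
Float64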
The same issue also exists for elu and possibly other activation functions. If the answer is to define adjoints for them I can do that, but I'm thinking that a more generic solution is perhaps wanted.
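For reference, a hand-written adjoint for selu that keeps the element type could look roughly like the sketch below. This is only an illustration of the per-function approach; whether Zygote's broadcasting machinery actually uses a scalar adjoint here is a separate question, and a more generic solution would avoid repeating this for every activation function.

using Flux

Flux.Zygote.@adjoint function selu(x::Real)
    λ = oftype(x/1, 1.0507009873554804934193349852946)
    α = oftype(x/1, 1.6732632423543772848170429916717)
    y = λ * ifelse(x > 0, x/one(x), α * (exp(x) - one(x)))
    # derivative: λ for x > 0, λ*α*exp(x) otherwise, kept in the element type of x
    y, Δ -> (Δ * λ * ifelse(x > 0, one(x), α * exp(x)),)
end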