FluxML / NNlib.jl

Neural Network primitives with multiple backends
Other
201 stars 121 forks source link

add Sparsemax activation #354

Open tylerjthomas9 opened 2 years ago

tylerjthomas9 commented 2 years ago

Source paper: http://arxiv.org/abs/1602.02068

PyTorch implementation: https://github.com/Qwicen/node/blob/master/lib/nn_utils.py

I started working on implementing sparsemax in Julia for TabNet. I thought that it would fit best in NNlib.jl. It should have exactly the same interface as softmax.

darsnack commented 2 years ago

If you have an implementation, then a PR would be welcome! We can iterate the design there.

tylerjthomas9 commented 2 years ago

I'm still trying to get the Jacobian to work, but I have an initial forward pass:

using NNlib
using LinearAlgebra
using Zygote

"""
    sparsemax(x; dims = 1)

Out-of-place entry point. Allocates a destination array shaped like `x` whose element
type is `float(eltype(x))`, then forwards it together with `x` and `dims` to
`sparsemax!`, returning whatever that call returns.
"""
function sparsemax(x; dims = 1)
    destination = similar(x, float(eltype(x)))
    return sparsemax!(destination, x; dims = dims)
end

"""
    sparsemax!(x; dims = 1)

Single-array convenience method. Forwards to the two-array method
`sparsemax!(out, x; dims)` with `x` passed as both destination and input.
"""
function sparsemax!(x; dims = 1)
    return sparsemax!(x, x; dims = dims)
end

"""
    sparsemax!(out, x; dims = 1)

Compute the sparsemax transformation of `x` along dimension `dims`, write the result
into `out`, and return `out`.

Each slice of the result along `dims` is non-negative and sums to one; entries outside
the support are exactly zero.

# Arguments
- `out::AbstractArray`: destination with the same size as `x`. It may alias `x`, since
  all intermediate values are computed in freshly allocated arrays before `out` is
  written.
- `x::AbstractArray`: input scores. It is not modified unless it is also passed as `out`.

# Keywords
- `dims`: integer dimension to normalise over; must lie in `1:ndims(x)`.

# Throws
- `ArgumentError` if `dims` is not an integer in `1:ndims(x)`.
- `DimensionMismatch` if `size(out) != size(x)`.
"""
function sparsemax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
    # Validate with real exceptions: `@assert` may be compiled out.
    (dims isa Integer && 1 <= dims <= ndims(x)) ||
        throw(ArgumentError("dims must be an integer in 1:$(ndims(x)), got $(dims)"))
    size(out) == size(x) ||
        throw(DimensionMismatch("out has size $(size(out)) but x has size $(size(x))"))
    isempty(x) && return out

    # Shift each slice so its maximum is zero. This allocates a new array, so the
    # caller's `x` is left untouched and `out === x` is safe.
    shifted = float.(x) .- maximum(x; dims = dims)
    S = eltype(shifted)

    # Ranks 1..d laid out along `dims` so they broadcast against the sorted values.
    d = size(x, dims)
    ranks = reshape(1:d, ntuple(i -> i == dims ? d : 1, ndims(x)))

    # Sort every slice in descending order (`sort` on a vector takes no `dims` keyword).
    sorted = if ndims(x) == 1
        sort(shifted; rev = true)
    else
        sort(shifted; dims = dims, rev = true)
    end

    # Rank k is in the support when k * z_(k) > (z_(1) + ... + z_(k)) - 1.
    partial = cumsum(sorted; dims = dims) .- one(S)
    support = ranks .* sorted .> partial
    support_size = sum(support; dims = dims)

    # Threshold: the shifted partial sum at rank `support_size`, divided by that rank.
    # The mask selects exactly one entry per slice, so no gather/diag step is needed.
    picked = ifelse.(ranks .== support_size, partial, zero(S))
    tau = sum(picked; dims = dims) ./ support_size

    # Project onto the simplex, writing into the caller-supplied destination.
    out .= max.(shifted .- tau, zero(S))
    return out
end

# 5×2 example matrix, written one row per line.
x = [
     0.3367  -0.1863
     0.1288   2.2082
     0.2345  -0.638
     0.2303   0.4617
    -1.1229   0.2674
]
println("Sparsemax probabilities")
# Last expression of the snippet: apply sparsemax along the first dimension.
sparsemax(x; dims = 1)
Sparsemax probabilities

2×5 Matrix{Float64}:
 0.7615  0.0  0.93625  0.3843  0.0
 0.2385  1.0  0.06375  0.6157  1.0