FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Random Fourier Features #2207

Open bicycle1885 opened 1 year ago

bicycle1885 commented 1 year ago

Motivation and description

I'd like to propose adding a layer to Flux.jl: random Fourier features. The layer is basically a mapping from low-dimensional features (e.g., 1D times, 3D coordinates, etc.) to high-dimensional embeddings. This type of mapping is known to make it easier for neural networks to learn high-frequency functions of low-dimensional inputs. A good summary and the original paper can be found at: https://bmild.github.io/fourfeat/.

The layer has no tunable parameters. Its weights are initialized with random samples from a Gaussian distribution and kept fixed after that. I'm not sure what would happen if we made them trainable.

If you like the idea, I'm happy to make a pull request.

Possible Implementation

This is a prototype of my proposal. The core of it can be implemented in fewer than 20 lines of code.

using Flux: @functor
using Optimisers: Optimisers

# https://bmild.github.io/fourfeat/
struct RandomFourierFeatures{A <: AbstractArray}
    in::Int
    out::Int
    W::A  # fixed frequency matrix of size cld(out, 2) × in
end

function RandomFourierFeatures((in, out), σ)
    # Frequencies are sampled once from a Gaussian with bandwidth σ and never updated.
    W = randn(typeof(σ), cld(out, 2), in) .* oftype(σ, 2π) * σ
    return RandomFourierFeatures(in, out, W)
end

@functor RandomFourierFeatures
Optimisers.trainable(::RandomFourierFeatures) = (;)  # no trainable parameters

function (rff::RandomFourierFeatures)(x::AbstractVecOrMat)
    Wx = rff.W * x
    # Concatenate cos and sin features; drop one sin row when `out` is odd.
    return [cos.(Wx); selectdim(sin.(Wx), 1, 1:size(Wx, 1)-isodd(rff.out))]
end
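
For illustration, here is how the prototype above might be used in front of an ordinary MLP (just a sketch; the sizes and σ value are arbitrary, and only the Dense layers would be trained):

using Flux

# Embed 3D coordinates into 256 fixed Fourier features before a small MLP.
rff = RandomFourierFeatures((3, 256), 10.0f0)  # σ = 10 sets the frequency bandwidth
model = Chain(rff, Dense(256 => 128, relu), Dense(128 => 3))

x = rand(Float32, 3, 32)  # a batch of 32 coordinates
y = model(x)              # 3×32 output; rff.W stays fixed during training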

RandomFourierEmbedding might be a better name.

ToucheSir commented 1 year ago

Thanks for the interesting reference. I'm not sure this passes the threshold of a sufficiently "mainstream" operation to be included in Flux core, but a separate package sounds like a good idea. If it turns out this becomes quite popular, we can always re-evaluate then :)

CarloLucibello commented 1 year ago

Is something like that used in neural radiance fields?

bicycle1885 commented 1 year ago

Yes, it is not exactly the same, but it is highly related. I think the proposed one is a spin-off of the NeRF paper, which proposes a mapping with fixed weights like the one below:

$$\gamma(p) = (\sin(2^0 \pi p), \cos(2^0 \pi p), \dots, \sin(2^{L-1} \pi p), \cos(2^{L-1} \pi p))$$
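
As a small illustrative sketch of that mapping (mine, not code from the paper), for a scalar p and L frequency levels it produces a 2L-dimensional vector:

# Minimal sketch of the NeRF-style encoding γ(p) for a scalar p and L levels.
γ(p, L) = vcat([[sin(2^l * π * p), cos(2^l * π * p)] for l in 0:L-1]...)

γ(0.3, 4)  # 8-element vector: sin(πp), cos(πp), sin(2πp), cos(2πp), …, sin(8πp), cos(8πp)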

The same idea is also used in the Transformer paper to encode token positions:

$$PE(pos, 2i) = \sin(pos / 10000^{2i/d}), \quad PE(pos, 2i+1) = \cos(pos / 10000^{2i/d})$$
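
(Again only a sketch for comparison, not the proposed layer: the Transformer encoding interleaves sin/cos pairs, each pair sharing a fixed frequency $1/10000^{2i/d}$.)

# Sketch of the sinusoidal positional encoding for one position `pos`
# and an even embedding size `d`, following the formula above.
function positional_encoding(pos, d)
    pe = zeros(Float32, d)
    for i in 0:d÷2-1
        ω = 1 / 10_000^(2i / d)    # fixed frequency for pair i
        pe[2i + 1] = sin(pos * ω)  # index 2i   (1-based: 2i + 1)
        pe[2i + 2] = cos(pos * ω)  # index 2i+1 (1-based: 2i + 2)
    end
    return pe
end

positional_encoding(5, 8)  # 8-dimensional encoding of position 5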

So this kind of idea is used in many fields and is "mainstream" enough. Perhaps we can generalize the weighting scheme so that it covers all of them.

CarloLucibello commented 1 year ago

Conflating all these methods into a single layer would be too hard to document, and they would take different arguments at construction time. But having a separate layer for each of them seems like a worthwhile contribution.

bicycle1885 commented 1 year ago

I mean something like this:

using Flux: @functor
using Optimisers: Optimisers

struct FourierEmbedding{A <: AbstractMatrix}
    in::Int
    out::Int
    W::A
end

function FourierEmbedding((in, out), weights)
    return FourierEmbedding(in, out, weights(in, out))
end

@functor FourierEmbedding
Optimisers.trainable(::FourierEmbedding) = (;)  # no trainable parameters

function (f::FourierEmbedding)(x::AbstractVecOrMat)
    Wx = f.W * x
    # Concatenate cos and sin features, then keep the first `out` rows.
    E = [cos.(Wx); sin.(Wx)]
    return selectdim(E, 1, 1:f.out)
end

function positional_weights(base = 10_000)
    function (in, out)
        W = zeros(Float32, cld(out, 2), in)
        for i in axes(W, 1)
            W[i,:] .= inv(base^(2 * (i - 1) / out))  # frequency 1 / base^(2i/d), matching the PE formula
        end
        return W
    end
end

function power2_weights(offset = 0)
    function (in, out)
        W = zeros(Float32, cld(out, 2), in)
        for i in axes(W, 1)
            W[i,:] .= 2.0f0^(i - offset - 1) * Float32(π)  # 2^l π frequencies as in NeRF
        end
        return W
    end
end

function gaussian_weights(σ)
    function (in, out)
        return randn(Float32, cld(out, 2), in) .* Float32(2π * σ)  # keep Float32 regardless of the type of σ
    end
end
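
With that design all three schemes share one layer type and only the weight generator changes; a usage sketch (sizes are arbitrary):

pos_emb   = FourierEmbedding((1, 64), positional_weights())       # Transformer-style positions
nerf_emb  = FourierEmbedding((3, 60), power2_weights())           # NeRF-style 2^l π frequencies
gauss_emb = FourierEmbedding((3, 256), gaussian_weights(10.0f0))  # random Fourier features

gauss_emb(rand(Float32, 3, 16))  # 256×16 embedding of a batch of 16 coordinates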