trthatcher / MLKernels.jl

Machine learning kernels in Julia.
http://trthatcher.github.io/MLKernels.jl/dev/
MIT License
78 stars 37 forks source link

Implement Automatic Relevance Determination #83

Open trthatcher opened 5 years ago

trthatcher commented 5 years ago

Implement Automatic Relevance Determination (ARD) as described in Rasmussen & Williams, "Gaussian Processes for Machine Learning", Chapter 5 (page 2 of the PDF, page 106 of the book):

http://www.gaussianprocess.org/gpml/chapters/RW5.pdf

trthatcher commented 5 years ago

@theogf Thank you for your ideas in GH-82! I like having a scaling variable field of type Union{Real,AbstractVector{<:Real}}.

I want to propose a couple of changes. In particular:

This would require a modification to the BaseFunction types. Outline of approach:

Here is a minimal working example for the RationalQuadraticKernel:

# Base Functions
"""
    BaseFunction{T<:Real}

Abstract supertype of the pairwise base functions that `unsafe_base_evaluate`
accumulates element by element over two arrays with element type `T`.
"""
abstract type BaseFunction{T<:Real} end

# NOTE: the `<:Real` bound must be repeated here. `BaseFunction{T}` only accepts
# `T<:Real`, so declaring the subtype with an unbounded `T` raises a `TypeError`
# when the definition is evaluated.
"""
    WeightedBaseFunction{T<:Real} <: BaseFunction{T}

Abstract supertype of base functions that carry their own per-element weights,
retrieved through `get_scale_factor`.
"""
abstract type WeightedBaseFunction{T<:Real} <: BaseFunction{T} end

"""
    SquaredEuclidean{T<:Real}(a)

Unweighted squared Euclidean base function. `a` is a single scale factor that
`base_return` multiplies into the accumulated sum of squared differences.
"""
struct SquaredEuclidean{T<:Real} <: BaseFunction{T}
    a::T
end

"""
    WeightedSquaredEuclidean{T<:Real}(w)

Weighted squared Euclidean base function. `w` holds one weight per element; each
squared difference is multiplied by its weight before being accumulated.
"""
struct WeightedSquaredEuclidean{T<:Real} <: WeightedBaseFunction{T}
    w::Vector{T}
end

# Used by the Kernel types to construct the appropriate base function for a scalar or vector scale
# A single scale value yields the unweighted squared Euclidean base function.
function base_sqdist(a::T) where {T}
    return SquaredEuclidean{T}(a)
end

# A vector of weights yields the weighted squared Euclidean base function.
function base_sqdist(w::AbstractVector{T}) where {T}
    return WeightedSquaredEuclidean{T}(w)
end

# Will be used in unsafe_base_evaluate to retrieve weight and pass through
# Return the weight vector stored in the weighted base function.
function get_scale_factor(f::WeightedSquaredEuclidean)
    return f.w
end

# Base Function Rules
# Starting value of the accumulator: the additive identity of the element type.
@inline function base_initiate(::BaseFunction{T}) where {T}
    return zero(T)
end

# Default finalisation: hand back the accumulated value unchanged.
@inline function base_return(::BaseFunction{T}, s::T) where {T}
    return s
end

# Add the squared difference of one element pair to the running sum `s`.
@inline function base_aggregate(::SquaredEuclidean{T}, s::T, x::T, y::T) where {T}
    Δ = x - y
    return s + Δ^2
end

# Weighted variant: the squared difference is multiplied by `w` before being added.
@inline function base_aggregate(
        ::WeightedSquaredEuclidean{T}, s::T, x::T, y::T, w::T
    ) where {T}
    Δ = x - y
    return s + w * Δ^2
end

# Unweighted squared distance: apply the single scale factor `a` once, at the end.
@inline function base_return(f::SquaredEuclidean{T}, s::T) where {T}
    return f.a * s
end

# Base Evaluation Changes
"""
    unsafe_base_evaluate(f, x, y)

Evaluate base function `f` on arrays `x` and `y`: start from `base_initiate(f)`,
fold every element pair in with `base_aggregate`, and finish with `base_return`.
Prints a trace line on every call.
"""
function unsafe_base_evaluate(
        f::BaseFunction{T},
        x::AbstractArray{T},
        y::AbstractArray{T}
    ) where {T<:Real}
    println("Running unweighted unsafe_base_evaluate")
    acc = base_initiate(f)
    @simd for idx in eachindex(x, y)
        # Indices come from `eachindex(x, y)`, so bounds checks are skipped.
        @inbounds acc = base_aggregate(f, acc, x[idx], y[idx])
    end
    return base_return(f, acc)
end

# These are new:
"""
    unsafe_base_evaluate(f, x, y, w)

Weighted evaluation of base function `f`: start from `base_initiate(f)`, fold every
`(x, y, w)` element triple in with the five-argument `base_aggregate`, and finish
with `base_return`. Prints a trace line on every call.
"""
function unsafe_base_evaluate(
        f::BaseFunction{T},
        x::AbstractArray{T},
        y::AbstractArray{T},
        w::AbstractArray{T}
    ) where {T<:Real}
    println("Running weighted unsafe_base_evaluate")
    acc = base_initiate(f)
    @simd for idx in eachindex(x, y, w)
        # Indices come from `eachindex(x, y, w)`, so bounds checks are skipped.
        @inbounds acc = base_aggregate(f, acc, x[idx], y[idx], w[idx])
    end
    return base_return(f, acc)
end

# Three-argument entry point for weighted base functions: fetch the weights stored
# in `f` and forward to the four-argument weighted evaluation.
@inline function unsafe_base_evaluate(
        f::WeightedBaseFunction{T},
        x::AbstractArray{T},
        y::AbstractArray{T}
    ) where {T<:Real}
    weights = get_scale_factor(f)
    return unsafe_base_evaluate(f, x, y, weights)
end

# Kernels could be defined as:
"""
    Scale{T}

A scale parameter: either a single value of type `T` or an `AbstractVector{T}` with
one entry per dimension.
"""
const Scale{T} = Union{AbstractVector{T},T}

"""
    Kernel{T<:Real}

Abstract supertype of kernels whose parameters have element type `T`.
"""
abstract type Kernel{T<:Real} end

"""
    RationalQuadraticKernel{T}(α, β)

Rational quadratic kernel. `α` is the scale (scalar or vector) handed to
`base_sqdist` to build the base function; `β` is the exponent used by `kappa`.

The concrete type of `α` is stored as the second type parameter `A`, so the field is
concretely typed and the scalar/vector choice is known to the compiler rather than
hidden behind an abstract `Union` field.
"""
struct RationalQuadraticKernel{T<:Real,A<:Scale{T}} <: Kernel{T}
    α::A
    β::T
end

# Keeps the one-parameter construction syntax `RationalQuadraticKernel{T}(α, β)`
# working: `A` is inferred from the argument, and `β` is converted to `T` by the
# default inner constructor.
function RationalQuadraticKernel{T}(α::A, β) where {T<:Real,A<:Scale{T}}
    return RationalQuadraticKernel{T,A}(α, β)
end

# Scalar part of the kernel: maps the base-function value `d²` to (1 + d²)^(-β).
@inline function kappa(κ::RationalQuadraticKernel{T}, d²::T) where {T}
    base = one(T) + d²
    power = -κ.β
    return base^power
end

# Build the base function for this kernel from its scale parameter `α`;
# `base_sqdist` picks the unweighted or weighted variant by dispatch.
@inline function basefunction(κ::RationalQuadraticKernel)
    return base_sqdist(κ.α)
end

# Demonstration

# Three kernels: a scalar scale, a uniform vector scale, and a per-dimension scale.
k1 = RationalQuadraticKernel{Float64}(2.0, 1.0)
k2 = RationalQuadraticKernel{Float64}(fill(2.0, 3), 1.0)
k3 = RationalQuadraticKernel{Float64}([1.0, 2.0, 3.0], 1.0)

# Base function for each kernel, selected by the type of its scale.
b1, b2, b3 = basefunction(k1), basefunction(k2), basefunction(k3)

# Random three-element inputs.
x, y = rand(3), rand(3)

# Each call prints which evaluation path (unweighted or weighted) was taken.
unsafe_base_evaluate(b1, x, y)
unsafe_base_evaluate(b2, x, y)
unsafe_base_evaluate(b3, x, y)

What are your thoughts on this approach?