zenna / OmegaCore.jl


Traits #6

Closed: zenna closed this issue 4 years ago

zenna commented 4 years ago

Traits are possible in Julia using functions on types and multiple dispatch.

The basic setup is as follows.

Suppose we have a function f and we want it to use different methods depending on whether its argument is indexable or not.

What we would like to write is something like

f(x::T) where {isindexable(T)} = x[1]
f(x) = first(x)

Unfortunately this is not valid Julia.

The workaround in Julia looks something like this:

# First, create types that record whether another type is indexable or not
struct IsIndexable end
struct IsNotIndexable end

# Then create a "trait function" that maps types to these traits.
# It receives the type itself, so its methods dispatch on Type{...}
trait_function_is_indexable(::Type) = IsNotIndexable()
trait_function_is_indexable(::Type{<:Array}) = IsIndexable()

# Finally, alter the definition of f.
# The one-argument method computes the trait from the argument's type
# and dispatches again on that trait value.
f(x::T) where T = f(trait_function_is_indexable(T), x)
f(::IsIndexable, x) = x[1]
f(::IsNotIndexable, x) = first(x)

It's a bit verbose, but it works nicely. It compensates for the fact that Julia does not have multiple abstract inheritance. If it did, we might imagine Array <: IsIndexable. But even then that's not ideal: traits let us attach more information to a type at any point, whereas with multiple abstract inheritance we would need to decide all the relevant traits up front.
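To illustrate the point about adding information at any point, here is a minimal self-contained sketch built on the pattern above. The variant of f that returns a tag saying which branch ran is hypothetical (made up for this illustration), and opting Tuple into the trait is just one example of a later extension:

```julia
struct IsIndexable end
struct IsNotIndexable end

trait_function_is_indexable(::Type) = IsNotIndexable()
trait_function_is_indexable(::Type{<:Array}) = IsIndexable()

# Hypothetical variant of f that also reports which branch it took.
f(x::T) where T = f(trait_function_is_indexable(T), x)
f(::IsIndexable, x) = (:indexed, x[1])
f(::IsNotIndexable, x) = (:first, first(x))

before = f((7, 8, 9))   # Tuple is not marked indexable yet, so first(x) is used

# Later, possibly in a different package, opt Tuple into the trait.
# Neither f nor the existing trait methods need to change.
trait_function_is_indexable(::Type{<:Tuple}) = IsIndexable()

after = f((7, 8, 9))    # now dispatches to the x[1] method
```

The only thing the new method touches is the trait function, which is what makes this more flexible than fixing a supertype when the type is declared.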

So what's wrong with this?

The limitation is that the trait-dispatch method f(x::T) where T = f(trait_function_is_indexable(T), x) has to fix which traits are relevant up front. We would like instead to write something like

f(x::T) where T = f(traits(T), x)
f(::IsIndexable, x) = x[1]
f(::IsNotIndexable, x) = first(x)

Could this possibly work?

traits would have to return something that is a supertype of both IsIndexable and IsNotIndexable.

using InteractiveUtils: subtypes  # loaded automatically in the REPL, needed in scripts
abstract type Traits{T} end
struct IsIndexable <: Traits{Vector} end
#struct IsNotIndexable <: Traits end

traits(::Type{T}) where T = Union{subtypes(Traits{T})...}

f(x::T) where T = f(traits(T), x)
f(::Type{IsIndexable}, x) = @show x[1]
f(_, x) = @show first(x)

Issues with this:

julia> f([1,2,3])
first(x) = 1
1

We wanted f([1,2,3]) to use x[1], but it doesn't: Julia's parametric types are invariant, so IsIndexable is a subtype of Traits{Vector} but not of Traits{Vector{Int}}, and subtypes(Traits{Vector{Int}}) comes back empty. Of course, we can remedy this as follows:
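A quick way to see the invariance problem is to check the subtype relations directly. This sketch reuses the Traits and IsIndexable definitions from the first attempt; the Traits{<:Vector} check at the end is an extra illustration, not something proposed in the thread:

```julia
abstract type Traits{T} end
struct IsIndexable <: Traits{Vector} end

# Vector{Int} is a subtype of the UnionAll type Vector ...
@assert Vector{Int} <: Vector
# ... but parametric types are invariant in their parameters,
# so Traits{Vector{Int}} and Traits{Vector} are unrelated.
@assert !(Traits{Vector{Int}} <: Traits{Vector})
@assert !(IsIndexable <: Traits{Vector{Int}})
# A covariant pattern written with `<:` does include IsIndexable.
@assert IsIndexable <: Traits{<:Vector}
```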

abstract type Traits{T} end
struct IsIndexable <: Traits{Vector{Int}} end

traits(::Type{T}) where T = Union{subtypes(Traits{T})...}

f(x::T) where T = f(traits(T), x)
f(::Type{IsIndexable}, x) = @show x[1]
f(_, x) = @show first(x)

julia> f((x for x = 1:10))
first(x) = 1
1

julia> f([1,2,3])
x[1] = 1
1

Or

abstract type Traits{T} end
struct IsIndexable{T} <: Traits{Vector{T}} end

traits(::Type{T}) where T = Union{subtypes(Traits{T})...}

f(x::T) where T = f(traits(T), x)
f(::Type{<:IsIndexable}, x) = @show x[1]
f(_, x) = @show first(x)

julia> f((x for x = 1:10))
first(x) = 1
1

julia> f([1,2,3])
x[1] = 1
1

Actually this is completely wrong. What was I thinking?

zenna commented 4 years ago

What about?

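struct TraitConjunct{T} end   # the set of trait names a type has
struct HasTraits{T} end       # meant to match any TraitConjunct containing the names in T

traits(::Type{<:Vector}) = TraitConjunct{(:isindexable, :somethingelse)}()
f(x::T) where T = f(traits(T), x)
f(::HasTraits{(:isindexable,)}, x) = @show x[1]
f(_, x) = @show first(x)

This doesn't work because TraitConjunct{(:isindexable, :somethingelse)} is not a subtype of HasTraits{(:isindexable,)}, so f([1,2,3]) always falls through to first(x).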

What about?

abstract type IsIndexableAndSomethingElse end

struct IsIndexable <:IsIndexableAndSomethingElse end
struct IsSomethingElse <: IsIndexableAndSomethingElse end

struct NoTraits end

traits(::Type) = NoTraits
traits(::Type{<:Vector}) = IsIndexableAndSomethingElse

f(x::T) where T = f(traits(T), x)
f(::Type{T}, x) where {T >: IsIndexable} = @show x[1]
f(_, x) = @show first(x)

f((x for x = 1:10))
f([1,2,3])
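The key ingredient here is the lower bound T >: IsIndexable: the method accepts any trait type that sits above IsIndexable in the hierarchy. A small sketch with the same type names, checking which trait types that signature accepts (the alias HasIndexable is a hypothetical name for illustration):

```julia
abstract type IsIndexableAndSomethingElse end
struct IsIndexable <: IsIndexableAndSomethingElse end
struct IsSomethingElse <: IsIndexableAndSomethingElse end
struct NoTraits end

# Hypothetical alias for the first argument of f(::Type{T}, x) where {T >: IsIndexable}
const HasIndexable = Type{T} where {T >: IsIndexable}

@assert IsIndexableAndSomethingElse isa HasIndexable  # conjunction containing IsIndexable
@assert IsIndexable isa HasIndexable                  # IsIndexable on its own
@assert !(IsSomethingElse isa HasIndexable)           # sibling trait only
@assert !(NoTraits isa HasIndexable)                  # no traits at all
```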
zenna commented 4 years ago

A poor man's solution that might work well when there is a small, finite number of tags:

# Make a type for every possible combination
abstract type HasTagErr end
abstract type HasTagErrAndLogPdf end
abstract type HasTagErrAndLogPdfAndMem end

We have around 5 tags (e.g. err, logpdf, mem, rng), so that's up to 2^5 = 32 types.

Then, for each tag, make a union of all the combination types that include it:

const HasLogPdf = Union{HasTagErrAndLogPdfAndMem, HasTagErrAndLogPdf, ...}

traits(ω)  # => HasTagErrAndLogPdfAndMem
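For a concrete feel, here is the same idea scaled down to two tags, so 2^2 = 4 combination types. The type names, the traits assignments, and the function g are all hypothetical, made up for illustration:

```julia
# One abstract type per combination of the (hypothetical) tags Err and LogPdf.
abstract type HasNoTags end
abstract type HasTagErr end
abstract type HasTagLogPdf end
abstract type HasTagErrAndLogPdf end

# For each tag, a Union of every combination that contains it.
const HasErr = Union{HasTagErr, HasTagErrAndLogPdf}
const HasLogPdf = Union{HasTagLogPdf, HasTagErrAndLogPdf}

# Hypothetical trait assignments.
traits(::Type) = HasNoTags
traits(::Type{<:Vector}) = HasTagErrAndLogPdf
traits(::Type{<:Tuple}) = HasTagErr

g(x::T) where T = g(traits(T), x)
g(::Type{<:HasLogPdf}, x) = :logpdf   # any combination that includes LogPdf
g(::Type, x) = :default
```

Each extra tag doubles the number of combination types and lengthens every per-tag Union, which is why this only makes sense for a small, fixed set of tags.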
zenna commented 4 years ago

That's not going to work.

Last try: Unions!


struct Err end
struct LogPdf end
struct Mem end
struct Intervene end

const alltags = Union{Err, LogPdf, Mem, Intervene}

traits(::Type{<:Vector}) = Union{LogPdf, Err}

trait(x) = Type{T} where {T >: x}

f(x::T) where T = f(traits(T), x)
f(::trait(LogPdf), x) = 1
f(::trait(Union{LogPdf, Err}), x) = 2
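A self-contained version of the Unions approach with a usage check. The traits method for Tuple is a hypothetical addition so that there is a type carrying only LogPdf:

```julia
struct Err end
struct LogPdf end
struct Mem end
struct Intervene end

traits(::Type{<:Vector}) = Union{LogPdf, Err}
traits(::Type{<:Tuple}) = LogPdf   # hypothetical: tuples carry only LogPdf

# Matches any trait Union that contains every tag in x.
trait(x) = Type{T} where {T >: x}

f(x::T) where T = f(traits(T), x)
f(::trait(LogPdf), x) = 1
f(::trait(Union{LogPdf, Err}), x) = 2
```

For a Vector both methods match, and the one requiring both tags is picked; for a Tuple only the LogPdf method matches.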
zenna commented 4 years ago

This kind of works. We would need traits to build the appropriate Union from the named tuple, but that shouldn't be problematic.
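One way traits might build that Union, sketched under the assumption that the trait names are the field names of a NamedTuple type. The tagtype mapping via Val is a hypothetical helper, not part of OmegaCore:

```julia
struct Err end
struct LogPdf end
struct Mem end

# Hypothetical mapping from named-tuple field names to tag types.
tagtype(::Val{:err}) = Err
tagtype(::Val{:logpdf}) = LogPdf
tagtype(::Val{:mem}) = Mem

# Build the trait Union from the field names K of a concrete NamedTuple type.
traits(::Type{NamedTuple{K, V}}) where {K, V} = Union{map(k -> tagtype(Val(k)), K)...}

trait(x) = Type{T} where {T >: x}

f(x::T) where T = f(traits(T), x)
f(::trait(LogPdf), x) = 1
f(::trait(Union{LogPdf, Err}), x) = 2
```

For example, (err = 1, logpdf = 0.0) gets the trait Union{Err, LogPdf} and hits the second method, while (logpdf = 0.0, mem = nothing) gets Union{LogPdf, Mem} and only matches the first.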

We're not really defining these traits incrementally, which is another shortcoming, but it's fine for Omega's use case.

The biggest problem is that specificity is inverted: we'd hope that a method for f(x::T) where T has traits A and B would be more specific than one where T has just A. Actually, I think it's fine.
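The "it's fine" conclusion can be checked with Julia's subtype operator: the signature built for two traits is a strict subtype of the one for a single trait, which is the main criterion Julia uses to treat one method as more specific than another. A small sketch using the trait helper from the Unions attempt:

```julia
struct Err end
struct LogPdf end

trait(x) = Type{T} where {T >: x}

# Every T with T >: Union{LogPdf, Err} also satisfies T >: LogPdf, but not vice versa,
# so the two-trait pattern is strictly narrower than the one-trait pattern.
@assert trait(Union{LogPdf, Err}) <: trait(LogPdf)
@assert !(trait(LogPdf) <: trait(Union{LogPdf, Err}))
```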

zenna commented 4 years ago

This fails type inference though:

struct Err end
struct LogPdf end
struct Mem end
struct Intervene end

const alltags = Union{Err, LogPdf, Mem, Intervene}

traits(::Type{<:Vector}) = Union{LogPdf, Err}
traits(::Type{T}) where T = Union{Bool}

trait(x) = Type{T} where {T >: x}

f(x::T) where T = f(traits(T), x)
f(::trait(Bool), x) = 1.0
f(::trait(Union{LogPdf, Err}), x) = 2

julia> @code_warntype f(1)
Variables
  #self#::Core.Compiler.Const(f, false)
  x::Int64

Body::Union{Float64, Int64}
1 ─ %1 = Main.traits($(Expr(:static_parameter, 1)))::Type
│   %2 = Main.f(%1, x)::Union{Float64, Int64}
└──      return %2