JuliaGPU / Adapt.jl

Support for arbitrary callables #84

Open termi-official opened 2 weeks ago

termi-official commented 2 weeks ago

adapt_structure fails if a Function subtype has both concretely typed fields and fields whose types are type parameters.

Tests:

```julia
using Adapt

# Number of type parameters of the argument's type (this should be equivalent to what adapt_structure does)
g(f::T) where T = length(T.parameters)

abstract type AbstractSuperType{t} end

abstract type AbstractSuperType2{t} <: Function end

struct A{T}
    a::T
end
a = A(1)
@assert g(a) == 1 # works
Adapt.adapt_structure(1.0, a) # works

struct B{T} <: AbstractSuperType2{true}
    b::T
end
b = B(1)
@assert g(b) == 1 # works
Adapt.adapt_structure(1.0, b) # works

struct B2 <: AbstractSuperType2{true}
    b::Int
end
b2 = B2(1)
@assert g(b2) == 0 # works
Adapt.adapt_structure(1.0, b2) # works

struct C{T} <: AbstractSuperType2{true}
    b::T
    a::Int
end
c = C(1,1)
@assert g(c) == 1 # works
Adapt.adapt_structure(1.0, c) # fails

struct C2{T}
    b::T
    a::Int
end
c2 = C2(1,1)
@assert g(c2) == 1 # works
Adapt.adapt_structure(1.0, c2) # works

struct C3{T} <: AbstractSuperType{true}
    b::T
    a::Int
end
c3 = C3(1,1)
@assert g(c3) == 1 # works
Adapt.adapt_structure(1.0, c3) # works

struct D{T} <: AbstractSuperType2{true}
    a::Int
    b::T
end
d = D(1,1)
@assert g(d) == 1 # works
Adapt.adapt_structure(1.0, d) # fails

struct D2{T}
    a::Int
    b::T
end
d2 = D2(1,1)
@assert g(d2) == 1 # works
Adapt.adapt_structure(1.0, d2) # works

struct D3{T} <: AbstractSuperType{true}
    a::Int
    b::T
end
d3 = D3(1,1)
@assert g(d3) == 1 # works
Adapt.adapt_structure(1.0, d3) # works

struct D4{T} <: AbstractSuperType2{true}
    a::Int
    b::T
end
d4 = D4(1,1)
@assert g(d4) == 1 # works
Adapt.adapt_structure(1.0, d4) # fails

struct E{T} <: AbstractSuperType2{true}
    a::Int
    b::T
    c::Int
end
e = E(1,1,1)
@assert g(e) == 1 # works
Adapt.adapt_structure(1.0, e) # fails
```

maleadt commented 2 weeks ago

This is a known limitation; Adapt currently only works with callable objects (functions, closures) as emitted by the Julia front-end.

Why are you even declaring AbstractSuperType2{t} <: Function in the first place? Can you show an example where this fails with normal functions? By subtyping Function, you're triggering the predefined adapt rule for Julia-defined functions (closures), which your types aren't. I'd suggest not subtyping Function and writing your own adapt converters instead.
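A minimal sketch of that suggestion, assuming Adapt's usual adapt / adapt_structure / adapt_storage entry points: the struct below keeps the mixed field layout of case C but is not declared <: Function, and a hand-written adapt_structure method rebuilds it field by field. MyOperator and the ToFloat32 adaptor are hypothetical names made up for this example.

```julia
using Adapt

# Hypothetical adaptor for this sketch: converts real-valued arrays to Float32.
struct ToFloat32 end
Adapt.adapt_storage(::ToFloat32, xs::AbstractArray{<:Real}) = Float32.(xs)

# Same mixed field layout as case C, but NOT declared `<: Function`.
struct MyOperator{T}
    b::T
    a::Int
end

# Instances are still callable; the (u, p, t) signature is only illustrative.
(op::MyOperator)(u, p, t) = op.a .* op.b .+ u

# Hand-written rule: rebuild the struct, adapting each field recursively.
Adapt.adapt_structure(to, op::MyOperator) =
    MyOperator(adapt(to, op.b), adapt(to, op.a))

op = MyOperator([1.0, 2.0], 3)
op32 = adapt(ToFloat32(), op)   # b becomes a Vector{Float32}; a stays 3
```

Because the rule is written for MyOperator itself, it makes no assumption about how the type's parameters map onto its fields.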

termi-official commented 2 weeks ago

> Why are you even declaring AbstractSuperType2{t} <: Function in the first place?

I am working on a package downstream of SciMLBase where some of my objects subtype https://github.com/SciML/SciMLBase.jl/blob/e3a0de8451d7a924807975de892673404c5b8d9a/src/SciMLBase.jl#L578 to represent my PDE discretization. I simply wanted a reproducer with a structure similar to my code (and to show that the issue is not subtyping per se, but subtyping from Function).

> Can you show an example where this fails with normal functions? [...]

I was not able to reproduce the issue with normal functions.

> This is a known limitation; Adapt currently only works with callable objects (functions, closures) as emitted by the Julia front-end.

Please note that Adapt works as expected in the following cases:

  1. My struct has no type parameters and only concretely typed fields (e.g. B2).
  2. The type of every field in my struct is given by a type parameter (e.g. B).

I am currently using option 2 to work around this issue. I just wanted to report it, since I spent some time tracking down the exact problem, so that others don't have to.
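The two working shapes can be expressed as a check on a type's original definition. The sketch below uses a hypothetical helper, closure_like_layout, which is not part of Adapt and only encodes the observations above, not Adapt's actual implementation. It relies on Base.unwrap_unionall and on the internal .parameters and .types fields of DataType, and compares a case-C layout with a workaround-2 version in which every field gets its own type parameter (CMixed, CParam and B2Like are made-up names).

```julia
# Same supertype as in the report: an abstract parametric type that subtypes Function.
abstract type AbstractSuperType2{t} <: Function end

# Mixed layout, as in case C (reported to fail with adapt_structure).
struct CMixed{T} <: AbstractSuperType2{true}
    b::T
    a::Int
end

# Workaround 2: every field's type is its own type parameter (the case-B shape).
struct CParam{T, I} <: AbstractSuperType2{true}
    b::T
    a::I
end

# No type parameters and only concrete fields (the case-B2 shape).
struct B2Like <: AbstractSuperType2{true}
    b::Int
end

# Hypothetical helper: inspects the type's definition (not the concrete instance
# type) and reports whether it has one of the two shapes that worked.
# Note: `.parameters` and `.types` are internal fields of DataType.
function closure_like_layout(F::Type)
    body = Base.unwrap_unionall(Base.typename(F).wrapper)
    isempty(body.parameters) && return true       # shape 1: no type parameters
    return all(t -> t isa TypeVar, body.types)    # shape 2: every field type is a parameter
end

closure_like_layout(typeof(CParam(1, 1)))   # true
closure_like_layout(typeof(CMixed(1, 1)))   # false
```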

maleadt commented 2 weeks ago

> I am working on a package downstream of SciMLBase where some of my objects subtype https://github.com/SciML/SciMLBase.jl/blob/e3a0de8451d7a924807975de892673404c5b8d9a/src/SciMLBase.jl#L578 to represent my PDE discretization. I simply wanted a reproducer with a structure similar to my code (and to show that the issue is not subtyping per se, but subtyping from Function).

@ChrisRackauckas What's the reason for the <: Function there?

ChrisRackauckas commented 2 weeks ago

It's a function; it has f(u,p,t) calls.

maleadt commented 2 weeks ago

> It's a function; it has f(u,p,t) calls.

Any object is callable once you add methods to its type. Why does it need the <: Function?
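To illustrate the point: a plain struct becomes callable as soon as a method is defined on its type, without subtyping Function. Discretization and its (u, p, t) method are hypothetical names used only for this sketch.

```julia
# Hypothetical callable struct; note that it does not subtype Function.
struct Discretization{T}
    coeffs::T
    scale::Int
end

# Defining a method on the type makes every instance callable as d(u, p, t).
(d::Discretization)(u, p, t) = d.scale .* d.coeffs .* u .+ p * t

d = Discretization([1.0, 2.0], 2)
d([1.0, 1.0], 0.5, 2.0)   # [3.0, 5.0]
d isa Function            # false
```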

oscardssmith commented 2 weeks ago

It's pretty common for people to annotate callable structs as <: Function.
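If the Function supertype has to stay (for example because it comes from an upstream package), one option is a type-specific method: Julia's dispatch prefers a method on the concrete struct over a method constrained only by F <: Function, so the generic closure rule is no longer selected for that type. This sketch assumes the generic rule is defined on adapt_structure, as the failing calls in the report suggest; AbstractOperatorLike and MixedOp are hypothetical names.

```julia
using Adapt

# Hypothetical stand-in for an upstream abstract type that already subtypes Function.
abstract type AbstractOperatorLike <: Function end

# Mixed layout as in case C: one parametric field and one concretely typed field.
struct MixedOp{T} <: AbstractOperatorLike
    b::T
    a::Int
end

# A method on the concrete struct is more specific than one constrained only by
# F <: Function, so dispatch picks this method for MixedOp instead of the generic rule.
Adapt.adapt_structure(to, op::MixedOp) = MixedOp(adapt(to, op.b), op.a)

m = MixedOp(1, 1)
Adapt.adapt_structure(1.0, m)   # the same kind of call that failed for case C
```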