JuliaIO / JLD2.jl

HDF5-compatible file format in pure Julia

Custom Serialization for parametric types #418

Closed: HumpyBlumpy closed this issue 1 year ago

HumpyBlumpy commented 1 year ago

Hello,

I am trying to use the custom serialization feature to save space when storing large structs with several fields. I am only interested in part of the data, so I define a new type that holds just that part, together with the writeas and convert methods described in the documentation. However, the struct is not converted automatically; it is saved as the original type instead.

Here is a minimal example illustrating my problem (my actual use case differs only in that it uses other custom types):

using JLD2

cd(@__DIR__)

struct Point{A<:Number,B<:Real}
    x::A
    y::B
end

struct PointSerial{A<:Number}
    x::A
end

JLD2.writeas(::Type{Point}) = PointSerial
Base.convert(::Type{PointSerial}, b::Point) = PointSerial(b.x)
#Base.convert(::Type{Point}, x::PointSerial) = Point(x.x,0)

thing = Point(5,6.4)

@save "test.jld2" thing
test2 = jldopen("test.jld2")

println(typeof(thing)) #Point{Int64, Float64}
println(typeof(test2["thing"])) #Point{Int64, Float64} should be PointSerial{Int64}
thing_converted = convert(PointSerial,thing) #PointSerial{Int64}

Am I doing something wrong?

JonasIsensee commented 1 year ago

You need to define the correct method signatures:

JLD2.writeas(::Type{<:Point{A}}) where A = PointSerial{A}
Base.convert(::Type{PointSerial{A}}, b::Point{A}) where A = PointSerial(b.x)

note, that you need to also define the correct back-conversion method:

Base.convert(::Type{<:Point{A}}, b::PointSerial{A}) where A = Point(b.x,0.)
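A minimal sketch that exercises the forward and back conversions from this reply in plain Julia, without JLD2, so it only checks that the methods dispatch and does not show what JLD2 itself calls. The variable names s and back are illustrative. Note that the forward target has to be written with its parameter (PointSerial{Int}) for the first method to match.

```julia
struct Point{A<:Number,B<:Real}
    x::A
    y::B
end

struct PointSerial{A<:Number}
    x::A
end

# Forward conversion: keep only x; the target must be spelled PointSerial{A}.
Base.convert(::Type{PointSerial{A}}, b::Point{A}) where A = PointSerial(b.x)
# Back conversion: rebuild a Point, filling the dropped y field with 0.0.
Base.convert(::Type{<:Point{A}}, b::PointSerial{A}) where A = Point(b.x, 0.0)

thing = Point(5, 6.4)

# Forward: the requested type carries the same parameter A = Int as thing.
s = convert(PointSerial{Int}, thing)

# Back: request the original concrete type, here Point{Int, Float64}.
back = convert(typeof(thing), s)

println(typeof(s))
println(back)
```

The y value is not recovered on the way back; it is replaced by 0.0, which is the price of storing only x.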

HumpyBlumpy commented 1 year ago

Thanks @JonasIsensee! I am trying to understand this. It works in conjunction with JLD2, but it fails when I simply call convert:

one = Point(1, 0.0)
one_converted = convert(PointSerial, one)
ERROR: MethodError: Cannot `convert` an object of type 
  Point{Int64, Float64} to an object of type 
  PointSerial
Closest candidates are:
  convert(::Type{T}, ::T) where T at Base.jl:61

Why is that?

JonasIsensee commented 1 year ago

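This is due to how subtyping works for parametric types in Julia (https://docs.julialang.org/en/v1/manual/types/#Parametric-Types). The bare name PointSerial is a UnionAll type, which is not PointSerial{A} for any A, so it does not match ::Type{PointSerial{A}}. It works if you define Base.convert(::Type{<:PointSerial}, b::Point) = PointSerial(b.x)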
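To make the subtyping point concrete, here is a small sketch in plain Julia (no JLD2). The names bare_exact, bare_lower, c1 and c2 are only for illustration. The bare name PointSerial is a UnionAll, so it is not equal to PointSerial{A} for any A, but it is a subtype of itself, which is all that Type{<:PointSerial} asks for.

```julia
struct Point{A<:Number,B<:Real}
    x::A
    y::B
end

struct PointSerial{A<:Number}
    x::A
end

# The bare name PointSerial is a UnionAll. It is not equal to PointSerial{A}
# for any A, so an exact Type{PointSerial{...}} signature cannot match it ...
bare_exact = PointSerial isa Type{PointSerial{Int64}}
# ... but it is a subtype of itself, which is all Type{<:PointSerial} requires.
bare_lower = PointSerial isa Type{<:PointSerial}

# Lower-bound signature: accepts PointSerial as well as PointSerial{Int}.
Base.convert(::Type{<:PointSerial}, b::Point) = PointSerial(b.x)

p = Point(1, 0.0)
c1 = convert(PointSerial, p)        # bare target type
c2 = convert(PointSerial{Int}, p)   # fully parametrized target type
```

One caveat of this signature: it ignores any parameter in the requested type, so convert(PointSerial{Float64}, p) would still return a PointSerial{Int64}. A variant such as Base.convert(::Type{T}, b::Point) where {T<:PointSerial} = T(b.x) would honor the requested parameter, because the default constructor PointSerial{Float64}(1) converts the field to 1.0.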

HumpyBlumpy commented 1 year ago

Thanks, but what confuses me is that in that case the JLD2 conversion fails.

JonasIsensee commented 1 year ago

JLD2 just calls convert and writeas; there is no magic. writeas(::Type{Point}) is the wrong method signature, e.g.:

julia> p = Point(1,2)
Point{Int64, Int64}(1, 2)

julia> typeof(p)
Point{Int64, Int64}

julia> typeof(p) isa Type{Point}
false

julia> typeof(p) isa Type{<:Point}
true

So instead, JLD2 calls the fallback implementation, which is

writeas(T) = T

so no conversion is done.
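The dispatch behaviour described here can be reproduced without JLD2. In the sketch below, wrong_writeas and right_writeas are hypothetical stand-ins for JLD2.writeas (they are not part of JLD2); each has the same kind of generic fallback, writeas(T) = T, plus one specialised method.

```julia
struct Point{A<:Number,B<:Real}
    x::A
    y::B
end

struct PointSerial{A<:Number}
    x::A
end

# Hypothetical stand-ins for JLD2.writeas; they only illustrate dispatch.
# Each has the generic fallback writeas(T) = T.
wrong_writeas(T::Type) = T
# Matches only when called with the bare UnionAll Point itself.
wrong_writeas(::Type{Point}) = PointSerial

right_writeas(T::Type) = T
# Matches every concrete Point{A,B} and returns a concrete PointSerial{A}.
right_writeas(::Type{<:Point{A}}) where A = PointSerial{A}

PT = typeof(Point(1, 2))   # the concrete type a serializer would be asked about
w = wrong_writeas(PT)      # the fallback wins, so no conversion would happen
r = right_writeas(PT)      # the specialised method wins
```

Only wrong_writeas(Point), called with the bare name, would reach the specialised method, and a serializer never sees the bare name for an actual value.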

HumpyBlumpy commented 1 year ago

I tried using writeas(::Type{<:Point}) as you indicated.

To be concrete: With

JLD2.writeas(::Type{<:Point}) = PointSerial
Base.convert(::Type{PointSerial}, b::Point)  = PointSerial(b.x)
Base.convert(::Type{<:Point}, x::PointSerial) = Point(x.x,0.0)

I can successfully call convert(PointSerial, ...) without specifying the type parameters, but saving with JLD2 fails with

ERROR: MethodError: no method matching hasdata(::Type{PointSerial})

If instead we use

JLD2.writeas(::Type{<:Point{A}}) where A = PointSerial{A}
Base.convert(::Type{PointSerial{A}}, b::Point{A}) where A = PointSerial(b.x)
Base.convert(::Type{<:Point{A}}, b::PointSerial{A}) where A = Point(b.x,0.)

then JLD2 saving and loading works, but calling convert fails unless I specify the type parameters.

JonasIsensee commented 1 year ago

Yes, you can call that method yourself, but it is still not correct, since PointSerial is also a parametric type.

julia> typeof(thing_converted) isa Type{PointSerial}
false

julia> typeof(thing_converted) isa Type{<:PointSerial}
true

julia> typeof(thing_converted) isa Type{PointSerial{Int}}
true
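A related check in plain Julia: isconcretetype distinguishes the bare UnionAll from a fully parametrized type. It is an assumption here, not confirmed in this thread, that this is why JLD2 reported no method matching hasdata(::Type{PointSerial}) when writeas returned the bare PointSerial: laying out data in a file presumably needs a concrete type with known field types.

```julia
struct PointSerial{A<:Number}
    x::A
end

# The bare name is a UnionAll: the field type A is still unknown.
bare_is_unionall = PointSerial isa UnionAll
bare_is_concrete = isconcretetype(PointSerial)

# Fixing the parameter yields a concrete DataType with known field types.
full_is_concrete = isconcretetype(PointSerial{Int64})
full_fieldtypes = fieldtypes(PointSerial{Int64})
```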

I think the simplest set of definitions that does the correct thing is

JLD2.writeas(::Type{<:Point{A}}) where A = PointSerial{A}
Base.convert(::Type{<:PointSerial}, b::Point) = PointSerial(b.x)
Base.convert(::Type{<:Point{A}}, b::PointSerial{A}) where A = Point(b.x, 0.0)
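A minimal end-to-end sketch, assuming JLD2 is installed and that loading goes through the back-conversion as described in this thread. It uses lower-bound signatures, Type{<:Point{A}}, for writeas and for the back-conversion: typeof(thing) is the concrete Point{Int64, Float64}, which is a subtype of the UnionAll Point{Int64} but not equal to it, so an exact Type{Point{A}} signature would not match. The temporary file path is illustrative.

```julia
using JLD2

struct Point{A<:Number,B<:Real}
    x::A
    y::B
end

struct PointSerial{A<:Number}
    x::A
end

# Store any Point{A,B} on disk as the concrete PointSerial{A}.
JLD2.writeas(::Type{<:Point{A}}) where A = PointSerial{A}
# Used when writing: keep only the x field.
Base.convert(::Type{<:PointSerial}, b::Point) = PointSerial(b.x)
# Used when reading: rebuild a Point, filling the dropped y field with 0.0.
Base.convert(::Type{<:Point{A}}, b::PointSerial{A}) where A = Point(b.x, 0.0)

thing = Point(5, 6.4)
path = joinpath(mktempdir(), "test.jld2")

# The do-block form closes the file automatically.
jldopen(path, "w") do f
    f["thing"] = thing
end

loaded = jldopen(path, "r") do f
    f["thing"]
end

println(typeof(loaded))
println(loaded)
```

Only x is stored, so the loaded value comes back as a Point whose y is 0.0 rather than 6.4.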