TuringLang / AbstractPPL.jl

Common types and interfaces for probabilistic programming
http://turinglang.org/AbstractPPL.jl/
MIT License

Conversion of VarName to/from string #100

Closed penelopeysm closed 1 month ago

penelopeysm commented 1 month ago

Closes #98.

penelopeysm commented 1 month ago

@sunxd3 @yebai I'm not entirely sure what interface we want for this, so would like to ask you all for feedback. You can see what it currently does in the tests I added. Specifically:

Alternatively, we could go the whole way and serialise the entire internal structure of VarName. That would be the most correct thing to do, although I don't currently get the sense we need to go that far.

yebai commented 1 month ago

I've chosen to do varname_from_str(), but would it be better to instead make a new method VarName(::AbstractString)

Is it sensible to overload Base.Serialization.serialize and its counterpart deserialize for this purpose?

serialise the entire internal structure of VarName

I'm not sure I understand this; can you clarify?

penelopeysm commented 1 month ago

Base.Serialization

I had a look at that earlier, but it skips the 'string' bit and goes straight to 'IO stream'

serialize(stream::IO, value)
deserialize(stream)

meaning that it's not possible to use these methods as part of something larger (e.g. de/serialising a struct that itself contains VarNames). Basically it's not composable or, at least, I think it's very difficult to use in a composable manner.
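
For reference, a minimal sketch of what going through the stdlib looks like (in current Julia it lives in the `Serialization` standard library). The value has to be routed through an `IO` object such as an `IOBuffer`, and what comes out is a binary blob rather than readable text. `ToyRecord` is a made-up stand-in for illustration, not an AbstractPPL type.

```julia
using Serialization

# Hypothetical stand-in for a struct that contains a variable name.
struct ToyRecord
    name::Symbol
    index::Tuple{Int,Int}
end

record = ToyRecord(:x, (1, 2))

# serialize writes to an IO stream (or a file), so go through an in-memory buffer.
io = IOBuffer()
serialize(io, record)
bytes = take!(io)   # Vector{UInt8} in Julia's binary format, not text

# Reading back also needs a stream.
restored = deserialize(IOBuffer(bytes))
```

The round trip works, but the bytes cannot be embedded as a field inside a larger text format without an extra encoding step.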

Other serialisation libraries like Serde are composable, but we probably don't want to bring in another dependency :(

internal structure

By this, I'm talking about storing the exact types that VarName uses internally, which would let us correctly differentiate between concretised colons and abstract colons.

julia> using AbstractPPL; x = ones(15)

julia> dump(@varname(x[:], true))
VarName{:x, Accessors.IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}}}}
  optic: Accessors.IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}}}
    indices: Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}}
      1: AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}
        range: Base.OneTo{Int64}
          stop: Int64 15

julia> dump(@varname(x[:]))
VarName{:x, Accessors.IndexLens{Tuple{Colon}}}
  optic: Accessors.IndexLens{Tuple{Colon}}
    indices: Tuple{Colon}
      1: Colon() (function of type Colon)
yebai commented 1 month ago

By this, I'm talking about storing the exact types that VarName uses internally, which would let us correctly differentiate between concretised colons and abstract colons.

I don't have a strong view on this; any thoughts from @mhauru @sunxd3 @torfjelde?

mhauru commented 1 month ago

I don't really understand how concretisation gets used, so no strong opinion. I do agree with @penelopeysm's comment that dumping the entire structure, so one can reconstruct an object that passes a == comparison in all cases, sounds like the conceptually correct thing to do.

sunxd3 commented 1 month ago

A varname is concretized dynamically, using the size information available in the scope where the concretization is carried out. One of the points is that the varname will adapt to changing sizes, e.g.

julia> x = ones(10)
10-element Vector{Float64}:
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0

julia> vn1 = @varname(x[end])
x[10]

julia> x = ones(100)
100-element Vector{Float64}:
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 ⋮
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0

julia> vn2 = @varname(x[end])
x[100]

This is useful for supporting general Julia programs in Turing models.

(note Accessors.need_dynamic_optic(:(x[:])) == false and Accessors.need_dynamic_optic(:(x[end])) == true.)
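
As a rough analogy in plain Julia (a sketch that uses only Base functions, not the `@varname` macro itself), `end` and `:` only become concrete once an array is in hand. It is an assumption here that the macro's concretisation follows the same idea.

```julia
x = ones(10)

# `end` resolves to a concrete index using the array visible at that moment.
last_small = lastindex(x)
# A colon resolves to a slice over the array's axes.
slice_small = to_indices(x, (:,))

# Rebinding x to a larger array changes what both resolve to.
x = ones(100)
last_large = lastindex(x)
slice_large = to_indices(x, (:,))
```

The abstract `:` carries no size at all, whereas the resolved slice remembers the axis it was built from, which is the extra information a concretised varname has to store.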

sunxd3 commented 1 month ago

Regarding how we should convert a varname to a string, I think that ideally the conversion should be (1) easy to read and (2) a bijection. (By bijection I mean: varname(string(vn)) == vn for all vn, modulo the parse.)

I think saving the whole internal structure would be good, but it might make it hard to read. It would be nice to have a concrete example where this conversion might be used. (an example I can think of is from parameter names in Chain to varname which can be used with VarInfos.)

mhauru commented 1 month ago

an example I can think of is from parameter names in Chain to varname which can be used with VarInfos.

This was my use case that came up.

I think having the same format both machine-readable and human-readable is great, but if that's hard, we can always have one method for a full string representation and another for pretty-printing a summary.

penelopeysm commented 1 month ago

Thanks all!

The thing with Chains is that the varname will always be concretised:

using Turing

@model function demo(y)
    x = Vector{Real}(undef, 2)  # so that x[:] is concretised
    x[:] ~ MvNormal(fill(0, 2), 1)
    y ~ Normal(x[end], 1)
end

chain = sample(demo(1), NUTS(), 100, progress=false)

vn = collect(keys(chain.info.varname_to_symbol))[1]   # x[:][1]
dump(vn)

gives

AbstractPPL.VarName{:x, ComposedFunction{Accessors.IndexLens{Tuple{Int64}}, Accessors.IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}}}}}
  optic: (@o _[:][1]) (function of type ComposedFunction{Accessors.IndexLens{Tuple{Int64}}, Accessors.IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}}}})
    outer: Accessors.IndexLens{Tuple{Int64}}
      indices: Tuple{Int64}
        1: Int64 1
    inner: Accessors.IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}}}
      indices: Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}}
        1: AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}
          range: Base.OneTo{Int64}
            stop: Int64 2

So for this use case, it seems sensible to say that roundtrip deserialisation need only work for concretised varnames?

sunxd3 commented 1 month ago

say that roundtrip deserialisation need only work for concretised varnames

yeah, that would make sense to me. Also, x[:] with ConcretizedSlice and x[1:10] (the range doesn't matter here) should behave the same, so I don't mind discarding ConcretizedSlice and just using a simple range. (I might be wrong on the same-behavior bit; corner cases might exist.)

torfjelde commented 1 month ago

I'm personally very much in favour of storing a "true" representation of VarName that can be deserialized to pass something like ==. It's unclear to me why we want readability in this case?

sunxd3 commented 1 month ago

if all we want is serialization, then yeah, storing the full thing would be the right way to go. but do we want to use this "string-ify" also for the String cast/conversion like that of Symbol(varname)?

This also reminds me that we need to deal with non-standard variable names like var"x.a", so that the deserialization doesn't get confused.
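
To illustrate the ambiguity with plain Julia expressions (a sketch that does not use AbstractPPL): a symbol containing a dot and a property access print to the same text, so the text alone cannot be parsed back unambiguously.

```julia
# A symbol that contains a dot, and the expression for property access.
dotted_symbol = Symbol("x.a")
property_expr = :(x.a)

# Both print to the same text, so the plain text cannot tell them apart.
same_text = string(dotted_symbol) == string(property_expr)

# Parsing that text only ever recovers the property access.
parsed = Meta.parse("x.a")

# The var"..." form is what parses back to the dotted symbol.
parsed_var = Meta.parse("var\"x.a\"")
```

Any readable string format therefore needs to record whether the dot belongs to the name or to the accessor.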

torfjelde commented 1 month ago

but do we want to use this "string-ify" also for the String cast/conversion like that of Symbol(varname)?

Depends on what the application for this is (I'm somewhat OOL here).

IMO, there are currently two "use-cases" for string(varname):

  1. show
  2. As keys in MCMCChains.jl

(1) is fair, but requires no notion of "deserialization".

(2) is not a good motivation in my opinion. Instead, we should just be using VarName directly as keys in MCMCChains.Chains. VarName implements a hash, so AFAIK there's no reason why we wouldn't allow this.
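
A minimal sketch of the idea, using a made-up `ToyName` type rather than the real `VarName`: once `==` and `hash` agree, separately constructed but equal names find the same dictionary entry, so no string conversion is needed for keys. Whether MCMCChains.Chains could accept such keys is not something this sketch shows.

```julia
# Hypothetical name type with a mutable field, so equality must be defined by content.
struct ToyName
    sym::Symbol
    indices::Vector{Int}
end

# Equal content must imply equal hash for Dict lookups to work.
Base.:(==)(a::ToyName, b::ToyName) = a.sym == b.sym && a.indices == b.indices
Base.hash(n::ToyName, h::UInt) = hash(n.indices, hash(n.sym, h))

draws = Dict(ToyName(:x, [1]) => [0.1, 0.2], ToyName(:x, [2]) => [0.3, 0.4])

# A freshly constructed key with the same content finds the stored entry.
lookup = draws[ToyName(:x, [2])]
```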

penelopeysm commented 1 month ago

Current behaviour is much improved, I think, and works for both unconcretised and concretised slices, as well as the var"x.a" case (thanks for suggesting that @sunxd3!):

using AbstractPPL
using Printf

y = ones(10)
vns = [
    @varname(x),
    @varname(x.a),
    @varname(x.a.b),
    @varname(var"x.a"),
    @varname(x[1]),
    @varname(x[1:10]),
    @varname(x[1, 2]),
    @varname(x[:]),
    @varname(y[begin:end]),
    @varname(y[:], false),
    @varname(y[:], true),
]
for vn in vns
    @printf "%10s => %s\n" vn vn_to_string(vn)
end
         x => (sym = "x", optic = (type = "identity",))
       x.a => (sym = "x", optic = (type = "property", field = "a"))
     x.a.b => (sym = "x", optic = (type = "composed", outer = (type = "property", field = "b"), inner = (type = "property", field = "a")))
       x.a => (sym = "x.a", optic = (type = "identity",))
      x[1] => (sym = "x", optic = (type = "index", indices = "(1,)"))
   x[1:10] => (sym = "x", optic = (type = "index", indices = "(1:10,)"))
   x[1, 2] => (sym = "x", optic = (type = "index", indices = "(1, 2,)"))
      x[:] => (sym = "x", optic = (type = "index", indices = "(:,)"))
   y[1:10] => (sym = "y", optic = (type = "index", indices = "(1:10,)"))
      y[:] => (sym = "y", optic = (type = "index", indices = "(:,)"))
      y[:] => (sym = "y", optic = (type = "index", indices = "(ConcretizedSlice(Base.OneTo(10)),)"))

It fails if somebody tries to use custom array types, presumably because OffsetArrays is not in scope at the point where the call to eval is made. I think this is a reasonable limitation, though: the moment something is serialised to a string or a file, we lose all information about the context in which the varname was generated, including what libraries / functions are in scope, so I think it's fair to say we can't reconstruct the context when deserialising.

julia> using OffsetArrays: Origin

julia> z = Origin(4)(ones(10))
10-element OffsetArray(::Vector{Float64}, 4:13) with eltype Float64 with indices 4:13:
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0

julia> vn = @varname(z[:], true)
z[:]

julia> vn_from_string(vn_to_string(vn)) == vn
ERROR: UndefVarError: `OffsetArrays` not defined
Stacktrace:
 [1] top-level scope
   @ none:1
 [2] eval
   @ ./boot.jl:385 [inlined]
 [3] eval
   @ ~/ppl/appl/src/AbstractPPL.jl:1 [inlined]
 [4] nt_to_optic(nt::@NamedTuple{type::String, indices::String})
   @ AbstractPPL ~/ppl/appl/src/varname.jl:783
 [5] vn_from_string(str::String)
   @ AbstractPPL ~/ppl/appl/src/varname.jl:808
 [6] top-level scope
   @ REPL[63]:1
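
The scoping issue can be reproduced without any packages. This is a sketch under the assumption that deserialisation evaluates parsed text inside a fixed module; `Sandbox` and `NotLoaded` are made-up names.

```julia
# A fresh module standing in for the package where `eval` runs.
sandbox = Module(:Sandbox)

# Names from Base resolve, so plain ranges can be rebuilt from text.
indices = Core.eval(sandbox, Meta.parse("(1:10,)"))

# A name that was never loaded into that module cannot be resolved.
# `NotLoaded` is a made-up package name used only for illustration.
outcome = try
    Core.eval(sandbox, Meta.parse("(NotLoaded.IdOffsetRange(4:13),)"))
catch err
    err
end
```

Here `outcome` holds the caught error instead of a tuple of indices, which mirrors the failure in the stack trace above.
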
penelopeysm commented 1 month ago

Okay, I've reimplemented the {de,}serialisation with StructTypes. Getting string output from StructTypes also requires an additional library, e.g. JSON3, which is what I've gone with here.

In the benchmarks below, vn_from_string is the original implementation using eval(Meta.parse(...)), and vn_from_string2 is the new one. I've avoided pirating Accessors types by going via an intermediate dictionary representation, so that instead of serialising optics directly we convert them to a dict and serialise the dict.
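
A minimal sketch of the intermediate-dictionary idea, with toy optic types standing in for the Accessors ones. All type and function names below are hypothetical, and the JSON step is left out so that the example needs no packages.

```julia
# Toy stand-ins for optic types owned by another package (all names hypothetical).
abstract type ToyOptic end
struct ToyIdentity <: ToyOptic end
struct ToyProperty <: ToyOptic
    field::Symbol
end
struct ToyIndex <: ToyOptic
    indices::Tuple
end
struct ToyComposed <: ToyOptic
    outer::ToyOptic
    inner::ToyOptic
end

# Our own conversion functions: no methods are added to functions we do not own.
optic_to_dict(::ToyIdentity) = Dict{String,Any}("type" => "identity")
optic_to_dict(o::ToyProperty) = Dict{String,Any}("type" => "property", "field" => String(o.field))
optic_to_dict(o::ToyIndex) = Dict{String,Any}("type" => "index", "indices" => collect(o.indices))
optic_to_dict(o::ToyComposed) = Dict{String,Any}(
    "type" => "composed",
    "outer" => optic_to_dict(o.outer),
    "inner" => optic_to_dict(o.inner),
)

# The inverse reads the "type" tag and rebuilds the matching optic.
function dict_to_optic(d::AbstractDict)
    kind = d["type"]
    kind == "identity" && return ToyIdentity()
    kind == "property" && return ToyProperty(Symbol(d["field"]))
    kind == "index" && return ToyIndex(Tuple(d["indices"]))
    kind == "composed" && return ToyComposed(dict_to_optic(d["outer"]), dict_to_optic(d["inner"]))
    throw(ArgumentError("unknown optic type: $kind"))
end

optic = ToyComposed(ToyIndex((1, 2)), ToyProperty(:a))
roundtripped = dict_to_optic(optic_to_dict(optic))
```

The nested dictionary contains only strings, numbers, vectors and other dictionaries, so any JSON library could write it out without knowing about the optic types.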

Benchmarking code:

```julia
using AbstractPPL
using BenchmarkTools

y = ones(10)
z = ones(10, 5)
vns = [
    @varname(x),
    @varname(x.a),
    @varname(x.a.b),
    @varname(var"x.a"),
    @varname(x[1]),
    @varname(x[1:10]),
    @varname(x[1, 2]),
    @varname(x[:]),
    @varname(y[begin:end]),
    @varname(y[:], false),
    @varname(y[:], true),
    @varname(z[:], false),
    @varname(z[:], true),
    @varname(z[:,:], false),
    @varname(z[:,:], true),
]

# Method 1
println("Old")
strs = map(vn_to_string, vns)
@benchmark map(vn_from_string, strs)

# Method 2
println("New")
strs = map(vn_to_string2, vns)
@benchmark map(vn_from_string2, strs)
```

julia> @benchmark map(vn_from_string, strs)
BenchmarkTools.Trial: 1649 samples with 1 evaluation.
 Range (min … max):  2.873 ms … 74.691 ms  ┊ GC (min … max): 0.00% … 95.75%
 Time  (median):     2.920 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.032 ms ±  1.772 ms  ┊ GC (mean ± σ):  1.53% ±  2.80%

 Memory estimate: 242.80 KiB, allocs estimate: 4126.

julia> @benchmark map(vn_from_string2, strs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  135.750 μs …  66.464 ms  ┊ GC (min … max): 0.00% … 99.57%
 Time  (median):     138.666 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   149.350 μs ± 665.882 μs  ┊ GC (mean ± σ):  5.97% ±  3.66%

 Memory estimate: 105.80 KiB, allocs estimate: 1285.

The only cost is that this drops support for any ConcretizedSlice that doesn't contain Base.Slice(Base.OneTo(n)) as its range. We could in principle expand this, but I can't think of any setting where this wouldn't be the case (see tests).

Note that the original implementation already didn't permit ranges that were not defined in Base, so the only question would be whether there are other ranges in Base that we need to care about.

penelopeysm commented 1 month ago

Remaining todos if we're happy with the implementation:

torfjelde commented 1 month ago

Great stuff @penelopeysm :)

I've avoided pirating Accessors types by going via an intermediate dictionary representation, so that instead of serialising optics directly we convert them to a dict and serialise the dict.

If we're lucky, they might be open to adding StructTypes.jl as a conditional dependency in Accessors.jl, which would probably make this even nicer:)

torfjelde commented 1 month ago

At this point, would it make sense to just move all the varname-related stuff to a separate package, e.g. VarNames.jl? This way we could easily use it in MCMCChains for example.

penelopeysm commented 1 month ago

adding StructTypes.jl as a conditional dependency in Accessors.jl

Hmm yes an AccessorsStructTypesExt would make sense! One problem is that we are only really implementing serialisation for a subset of the types in Accessors (i.e. the types we care about). I could open an issue to see whether they would be interested in the code.

a separate package, e.g. VarNames.jl

AbstractPPL doesn't really have much code that isn't in varname.jl, so moving it elsewhere would leave this really empty. I wasn't around for the discussions about why this package was created, though, so happy to leave this decision to you all 😄

sunxd3 commented 1 month ago

I think the tradeoff between the new dependency and speed of deserialize is worth it, so the new implementation is good with me. This is fantastic, Penny!

This way we could easily use it in MCMCChains for example.

I am probably missing something, why is it difficult now? (I think MCMCChains and AbstractPPL don't depend on each other)

penelopeysm commented 1 month ago

Actually, coming back to this with a fresh pair of eyes, here's an even faster one which just goes via Dict -> JSON instead of using StructTypes:

julia> @benchmark map(vn_from_string2, strs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  68.208 μs …  62.339 ms  ┊ GC (min … max): 0.00% … 99.75%
 Time  (median):     70.583 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   77.770 μs ± 623.365 μs  ┊ GC (mean ± σ):  8.61% ±  1.89%

 Memory estimate: 42.80 KiB, allocs estimate: 677.
torfjelde commented 1 month ago

I am probably missing something, why is it difficult now? (I think MCMCChains and AbstractPPL don't depend on each other)

AbstractPPL.jl has a bunch of features that are completely unnecessary for MCMCChains.jl. There are several places where it would be beneficial to use VarNames, but it just seems "too much" to depend on AbstractPPL.jl.

torfjelde commented 1 month ago

Actually, coming back to this with a fresh pair of eyes, here's an even faster one which just goes via Dict -> JSON instead of using StructTypes:

Do we lose some extensibility here though? My "hope" was that if the "serialization interface" is external to AbstractPPL.jl, we'd be able to suggest that other packages also implement it through an extension or something, leading to VarName being fully serializable even when more complex indices, etc. are used in a VarName.

penelopeysm commented 1 month ago

Actually the StructTypes implementation wasn't very extensible either 😅 because it doesn't use StructTypes 'all the way down' – it just uses it for the top layer, and the bottom layers (most importantly, indices) are serialised manually:

https://github.com/TuringLang/AbstractPPL.jl/blob/73a426b04339cf9d7b56c55122e41558295cf855/src/varname.jl#L866-L899

For composability purposes I see a couple of options we can do:

  1. Rewrite such that it does use StructTypes all the way down. If we want to avoid type piracy, this would mean we do have to patch everything upstream: for example Accessors types would go in AccessorsStructTypesExt, OffsetArrays would go in OffsetArraysStructTypesExt, and Base range types would go into StructTypes itself.

  2. Make some small modifications to the two functions above so that they can be extended to new types by defining a new method. The responsibility for hosting the code for new types would then fall on us: for example, if we wanted to add OffsetArrays compatibility we would add an AbstractPPL.index_to_dict(::OffsetArrays.IdOffsetRange) (and the inverse) into an AbstractPPLOffsetArraysExt.
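
Option 2 could look roughly like the sketch below. `index_to_dict` is the name mentioned above; `dict_to_index`, the `"type"` tag layout, and `ToyShiftedRange` (a stand-in for something like an offset range) are assumptions made for illustration.

```julia
# Built-in index types handled by the package itself.
index_to_dict(i::Int) = Dict{String,Any}("type" => "Int", "value" => i)
index_to_dict(r::UnitRange{Int}) =
    Dict{String,Any}("type" => "UnitRange", "start" => first(r), "stop" => last(r))

# Decoding dispatches on the stored type tag.
dict_to_index(d::AbstractDict) = dict_to_index(Val(Symbol(d["type"])), d)
dict_to_index(::Val{:Int}, d) = d["value"]
dict_to_index(::Val{:UnitRange}, d) = d["start"]:d["stop"]

# An extension for a custom index type only has to add two methods.
struct ToyShiftedRange
    offset::Int
    length::Int
end
index_to_dict(r::ToyShiftedRange) =
    Dict{String,Any}("type" => "ToyShiftedRange", "offset" => r.offset, "length" => r.length)
dict_to_index(::Val{:ToyShiftedRange}, d) = ToyShiftedRange(d["offset"], d["length"])

# Round-trip a mix of built-in and extension-provided index types.
restored = map(dict_to_index ∘ index_to_dict, (3, 2:4, ToyShiftedRange(4, 10)))
```

Because both functions are owned by the package that defines them, adding methods for a third-party type from a package extension would not be type piracy.
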

Option 1 would be the correct solution if we believe that there is one true serialisation method for those upstream types, which every consumer should work with. I think this is a noble aim, but I feel very hesitant about proposing so many upstream changes :(

As a case study, our ConcretizedSlices (usually) wrap Base.OneTo, and it turns out that StructTypes already defines a serialisation method for that:

julia> using StructTypes, JSON3

julia> s = JSON3.write(Base.OneTo(5))
"[1,2,3,4,5]"

Unfortunately, deserialising this results in a loss of information, because there is no information in the string that says that it should be converted to Base.OneTo{Int}. (JSON3.read can take a target type as the second argument, but there's no way to find out the appropriate type from the string.)

julia> JSON3.read(s)
5-element JSON3.Array{Int64, Base.CodeUnits{UInt8, String}, Vector{UInt64}}:
 1
 2
 3
 4
 5

and although these are semantically the same (I discovered that Base.OneTo(5) == [1, 2, 3, 4, 5] is true 😄), it means we can't actually reconstruct a ConcretizedSlice:

julia> ConcretizedSlice(Base.Slice(JSON3.read(s)))
ERROR: MethodError: no method matching Base.Slice(::JSON3.Array{Int64, Base.CodeUnits{UInt8, String}, Vector{UInt64}})

Closest candidates are:
  Base.Slice(::Base.Slice)
   @ Base indices.jl:377
  Base.Slice(::T) where T<:AbstractUnitRange
   @ Base indices.jl:375

To fix this, we would have to patch StructTypes to serialise Base.OneTo(5) as something like

"{\"stop\":5,\"type\":\"Base.OneTo\"}"

so that we could use that type information in the deserialisation. And although that would indeed be a more faithful representation(!), this is the sort of thing I'm hesitant to do. I feel 'safer' keeping the serialisation function within AbstractPPL, which lets us be more opinionated and customise the exact behaviour to meet our needs (as opposed to the needs of everybody who imports StructTypes).
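
The difference between the two representations can be sketched with Base alone. The dictionary layout here just mirrors the JSON string above and is an assumption about how a tagged form might be stored.

```julia
r = Base.OneTo(5)

# Flattening to a plain list of numbers keeps the values but drops the type.
flat = collect(r)

# Storing a type tag next to the defining field keeps enough to rebuild it.
tagged = Dict{String,Any}("type" => "Base.OneTo", "stop" => last(r))
rebuilt = tagged["type"] == "Base.OneTo" ? Base.OneTo(tagged["stop"]) : error("unsupported range type")

# Only the rebuilt range is accepted where an AbstractUnitRange is required.
slice = Base.Slice(rebuilt)
```

The flat vector still compares equal to the original range, but only the tagged form gets back a value of the original type.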

Let's discuss this on Monday though :)

penelopeysm commented 1 month ago

I think this is ready for a final review. I wrote up docs here too which show how you can extend the interface to handle custom types: http://turinglang.org/AbstractPPL.jl/previews/PR100/api/#VarName-serialisation