probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference
https://gen.dev
Apache License 2.0
1.79k stars 160 forks source link

Help with split-merge involutive MCMC for a tree-structured change-point model #392

Open deoxyribose opened 3 years ago

deoxyribose commented 3 years ago

Hi,

I'm trying to do involutive MCMC for a 2d non-parametric change-point model (the same one as in #388). The model samples a tree structure that recursively divides an image into rectangles, and then samples pixel values for the whole image given the structure. Model code here, the relevant bits are

# Recursively sample a random binary tree from the prior. Each call flips a
# fair coin at address :isleaf: a leaf carries Normal(0,1) :mean and
# Gamma(1,1) :variance; a branch carries a split fraction :frac and an
# orientation :ishorizontal, and recurses into subtrees at addresses :a and :b.
@gen function grow_tree()
    if @trace(bernoulli(0.5), :isleaf)
        mean = @trace(normal(0, 1), :mean)
        variance = @trace(gamma(1, 1), :variance)
        return LeafNode(mean, variance)
    else
        # logistic squashes the unconstrained normal draw into (0, 1)
        frac = logistic(@trace(normal(0, 1), :frac))
        ishorizontal = @trace(bernoulli(0.5), :ishorizontal)
        a = @trace(grow_tree(), :a)
        b = @trace(grow_tree(), :b)
        return BranchNode(a,b,frac,ishorizontal)
    end
end;

# Generative model of an image of the given (nrows, ncols) size: sample a
# tree at address :tree that recursively partitions the image into rectangles,
# look up per-pixel (mean, variance) from the tree, and sample all pixels
# jointly at address :img from a broadcasted normal.
@gen function screen_model(size::Tuple{Int64,Int64})
    nrows, ncols = size
    screenshot = Array{Float64}(undef, nrows, ncols)
    tree = @trace(grow_tree(), :tree)
    is,js = get_index_matrices(screenshot)
    # get_value_at maps pixel (i, j) to the (mean, variance) of the rectangle
    # containing it; the tuple is the full-image bounding box.
    img_params = map((i,j) -> get_value_at(i,j,tree,(1.,1.,Float64(nrows),Float64(ncols))), is, js)
    img_mean = getindex.(img_params,1)
    img_variance = getindex.(img_params,2)
    # :img is sampled as a flat vector; reshaped to the image size on return.
    screenshot = @trace(broadcasted_normal([img_mean...],[img_variance...]), :img)
    return reshape(screenshot,size)
end

I wrote an involution inspired by the GP structure search in GenExamples.jl, where the proposal picks a random node in the tree, and samples a subtree that replaces whatever was at that node. Full code here.

# MH proposal: pick a random node in the current tree (a path of :a/:b
# symbols from the root, traced at :choose_subtree_root) and sample a fresh
# subtree from the prior at :subtree to replace it.
# Returns (path, new_subtree_node) for use by the involution.
@gen function subtree_proposal(prev_trace)
    prev_subtree_node::Node = prev_trace[:tree]
    (path::Vector{Symbol}) = @trace(pick_random_node_path(prev_subtree_node, Symbol[]), :choose_subtree_root)
    # NOTE(review): subtree_addr is computed but never used in this function
    subtree_addr = isempty(path) ? :tree : (:tree => foldr(=>, path))
    new_subtree_node::Node = @trace(grow_tree(), :subtree) # mixed discrete / continuous
    (path, new_subtree_node)
end

# Involution for subtree_proposal: the proposal's return value supplies the
# path; the proposed subtree choices are spliced into the model at that path,
# and the model's previous subtree choices become the backward proposal's
# :subtree, making the transform its own inverse.
@transform subtree_involution_tree_transform (model_in, aux_in) to (model_out, aux_out) begin
    # aux_in[] reads the proposal's return value (path, new_subtree_node)
    (path::Vector{Symbol}, new_subtree_node) = @read(aux_in[], :discrete)

    # populate backward assignment with choice of root
    @copy(aux_in[:choose_subtree_root], aux_out[:choose_subtree_root])

    # swap subtrees
    model_subtree_addr = isempty(path) ? :tree : (:tree => foldr(=>, path))
    @copy(aux_in[:subtree], model_out[model_subtree_addr])
    @copy(model_in[model_subtree_addr], aux_out[:subtree])
end

This runs without errors, but few traces are accepted, probably because randomly proposed trees aren't a good fit, and even a well-fitting tree usually has bad values (:mean and :variance) at the leaf nodes.

So I want to write a split-merge involution which proposes to split a leaf node, or merge sibling leaf nodes, along with :mean and :variance parameters derived from the current trace, similar to the mixture example here. To begin with, I skip the parameter proposals and just propose a random split/merge, like this:

# Split/merge proposal (first version): flip a fair coin at :split to choose
# between splitting a random leaf into a branch and merging a random pair of
# sibling leaves into one leaf. A single-leaf tree forces a split, since
# there is nothing to merge (the :split coin is still traced either way).
@gen function split_merge_proposal(prev_trace)
    tree = prev_trace[:tree]
    n_leaf_nodes = count_leaves(tree)
    random_split = @trace(bernoulli(0.5), :split)
    split = (n_leaf_nodes == 1) ? true : random_split
    if split
        # select random leaf for splitting
        leaf_path = @trace(pick_random_leaf(tree), :leaf_path)
        new_node = @trace(make_branch(), :new_node)
    else
        # select random leaf that will merge with sibling
        # sibling needs to be a leaf as well
        leaf_path = @trace(pick_random_leaf_parent(tree), :leaf_path)
        new_node = @trace(make_leaf(), :new_node)
    end
    (leaf_path, new_node)
end

# Involution for the first split/merge proposal: swap the node at the chosen
# path between model and auxiliary traces, and invert the :split flag so the
# backward move undoes the forward one.
# NOTE(review): per the discussion below, reading/writing :leaf_path as a
# single choice fails because pick_random_leaf is a generative function that
# makes multiple traced choices, not a single random choice.
@transform split_merge_transform (model_in, aux_in) to (model_out, aux_out) begin
    leaf_path = @read(aux_in[:leaf_path], :discrete)
    new_node = @read(aux_in[:new_node], :discrete)

    tree = @read(model_in[:tree], :discrete)
    n_leaf_nodes = count_leaves(tree)
    random_split = @read(aux_in[:split], :discrete)
    # mirror the proposal's forced-split rule for a single-leaf tree
    split = (n_leaf_nodes == 1) ? true : random_split

    new_node_addr = isempty(leaf_path) ? :tree : (:tree => foldr(=>, leaf_path))

    # backward proposal must pick the same location
    @write(aux_out[:leaf_path], leaf_path, :discrete)

    # a backward split reverses a forward merge and vice versa
    @write(aux_out[:split], !random_split, :discrete)
    @copy(aux_in[:new_node], model_out[new_node_addr])
    @copy(model_in[new_node_addr], aux_out[:new_node])
end

Full code here

Just like :choose_subtree_root in the subtree involution, to perform the correct reversal, aux_out needs the same value at :leaf_path that aux_in had, so I did exactly the same thing: `@copy(aux_in[:leaf_path], aux_out[:leaf_path])`. But here, it runs for a few iterations and then crashes on `ERROR: transform round trip check failed`, because the copied :leaf_path would, for some reason, differ from the original one, and so the new node that was proposed was spliced back at the wrong address, resulting in different model traces before and after the round trip. I tried with `@write(aux_out[:leaf_path], leaf_path, :discrete)` instead, and now I get `ERROR: KeyError: key :leaf_path not found` for either
leaf_path = @trace(pick_random_leaf(tree), :leaf_path) or leaf_path = @trace(pick_random_leaf_parent(tree), :leaf_path) in the proposal.

I hope this makes sense, let me know if I should clarify anything. Ultimately I want to do split-merge jumps that derive the leaf node values from the previous state (e.g. take the means of the mean and variance parameters of two merging leaf nodes), and then do MAP optimization on all the continuous parameters. But maybe there's a better way to do inference?

Thanks in advance for any help.

marcoct commented 3 years ago

I think the issue might be that pick_random_leaf is not sampling a single random choice, but is itself a generative function that samples multiple random choices. If that's the case, you can either make pick_random_leaf into a Gen.Distribution (https://www.gen.dev/dev/ref/extending/#custom_distributions-1) or you can extend the transform with a call to a recursive transform (that you call with @tcall, see https://www.gen.dev/dev/ref/trace_translators/#Trace-Transform-DSL-1) that walks the tree and populates values of all the choices made by pick_random_leaf (see https://github.com/probcomp/GenExamples.jl/blob/main/gp_structure/involution_mh.jl#L91-L155 for an example of that approach).

Note that while you can @copy entire choice maps from one trace to another, you can only @read and @write individual random choices.

deoxyribose commented 3 years ago

Thank you for the answer! I'm not sure if it's exactly what you meant, but I tried the recursive transform approach like this:

# Random walk from `node` down to a leaf, tracing one fair coin per branch
# visited at address :recurse_a => cur. `cur` is the node's heap-style index
# (children via get_child) and `depth` its depth from the root.
# Returns (cur, depth) of the leaf reached.
@gen function pick_random_leaf(node::Node, cur::Int, depth::Int)
    if isa(node, LeafNode)
        (cur, depth)
    elseif @trace(bernoulli(0.5), :recurse_a => cur)
        @trace(pick_random_leaf(node.a, get_child(cur, 1, 2), depth+1))
    else
        @trace(pick_random_leaf(node.b, get_child(cur, 2, 2), depth+1))
    end
end

# Random walk from `node` down to a branch whose children are both leaves
# (a merge candidate), returning its (cur, depth). When only one child is a
# branch, the recursion direction is forced with a deterministic bernoulli
# (prob 1. or 0.) so the choice is still recorded in the trace at
# :recurse_a => cur. Assumes `node` is a BranchNode (it reads node.a/node.b).
@gen function pick_random_leaf_parent(node::Node, cur::Int, depth::Int)
    if isa(node.a, LeafNode) && isa(node.b, LeafNode)
        return (cur, depth)
    else
        if isa(node.a, BranchNode) && isa(node.b, BranchNode)
            recurse_a_prob = 0.5
        elseif isa(node.a, BranchNode) && !isa(node.b, BranchNode)
            recurse_a_prob = 1.
        elseif isa(node.b, BranchNode) && !isa(node.a, BranchNode)
            recurse_a_prob = 0.
        else
            # unreachable: both-leaf case returned above — fail loudly if hit
            error(node)
        end
    end
    if @trace(bernoulli(recurse_a_prob), :recurse_a => cur)
        @trace(pick_random_leaf_parent(node.a, get_child(cur, 1, 2), depth+1))
    else
        @trace(pick_random_leaf_parent(node.b, get_child(cur, 2, 2), depth+1))
    end
end

# Split/merge proposal (recursive-walk version): same structure as before,
# but the leaf/leaf-parent is now located by a traced random walk
# (pick_random_leaf / pick_random_leaf_parent) starting at the root
# (index 1, depth 0), so each per-node coin is addressable by the transform.
@gen function split_merge_proposal(prev_trace)
    tree = prev_trace[:tree]
    n_leaf_nodes = count_leaves(tree)
    random_split = @trace(bernoulli(0.5), :split)
    # a single-leaf tree has nothing to merge, so force a split
    split = (n_leaf_nodes == 1) ? true : random_split
    if split
        # select random leaf for splitting
        leaf_path = @trace(pick_random_leaf(tree, 1, 0), :leaf_path)
        new_node = @trace(make_branch(), :new_node)
    else
        # select random leaf that will merge with sibling
        # sibling needs to be a leaf as well
        leaf_path = @trace(pick_random_leaf_parent(tree, 1, 0), :leaf_path)
        new_node = @trace(make_leaf(), :new_node)
    end
    (leaf_path, new_node)
end

# Recursive helper transform (called via @tcall): retraces the proposal's
# random walk one node at a time. `cur` is the heap-style index of the node
# being visited and `leaf_path` accumulates :a/:b symbols along the way.
# On reaching the chosen node, swaps the node between model and aux traces.
@transform walk_tree(cur::Int, leaf_path::Array{Symbol,1}) (model_in, aux_in) to (model_out, aux_out) begin
    # the walk generative function returned (leaf index, depth)
    (leaf_number, leaf_depth) = @read(aux_in[:leaf_path], :discrete)
    if leaf_number == cur
        new_node_addr = isempty(leaf_path) ? :tree : (:tree => foldr(=>, leaf_path))
        # splice the proposed node into the model; the displaced model node
        # becomes the backward proposal's :new_node
        @copy(aux_in[:new_node], model_out[new_node_addr])
        @copy(model_in[new_node_addr], aux_out[:new_node])
    else
        # follow the coin traced at this node by the walk
        recurse_a = @read(aux_in[:leaf_path => :recurse_a => cur], :discrete)
        if recurse_a
            push!(leaf_path, :a)
            @tcall(walk_tree(get_child(cur, 1, 2), leaf_path))
        else
            push!(leaf_path, :b)
            @tcall(walk_tree(get_child(cur, 2, 2), leaf_path))
        end
    end
end

# Involution for the recursive-walk split/merge proposal: copies the whole
# :leaf_path choice map to the backward proposal, inverts the :split flag,
# and delegates the node swap to the recursive walk_tree transform starting
# at the root (index 1, empty path).
@transform split_merge_transform (model_in, aux_in) to (model_out, aux_out) begin
    (leaf_number, leaf_depth) = @read(aux_in[:leaf_path], :discrete)
    new_node = @read(aux_in[:new_node], :discrete)

    tree = @read(model_in[:tree], :discrete)
    n_leaf_nodes = count_leaves(tree)
    random_split = @read(aux_in[:split], :discrete)
    # mirror the proposal's forced-split rule for a single-leaf tree
    split = (n_leaf_nodes == 1) ? true : random_split

    # entire choice maps may be @copy'd; only single choices may be @read/@write
    @copy(aux_in[:leaf_path], aux_out[:leaf_path])
    @write(aux_out[:split], !random_split, :discrete)
    @tcall(walk_tree(1,Symbol[]))
end

It does pass the round trip checks, and usually gets many iterations in, but eventually crashes on ERROR: Did not visit all constraints Any idea what might be going wrong, or how I can inspect what constraints aren't visited and why?

I've also been trying to do map_optimize on a selection of the continuous choices in the trace, like this:

# Build a Gen selection covering every continuous choice in the tree rooted
# at `node`: the :mean and :variance of each leaf, and the :frac of every
# branch on the path to any leaf (each branch selected once).
# Returns a DynamicSelection suitable for e.g. map_optimize.
function select_continuous(node)
    selection = DynamicSelection()
    leaves = get_path_to_leaf(node)
    # Track already-selected branch prefixes in a typed Set: O(1) membership
    # instead of `∉` on an untyped Vector{Any} (which was O(n) per check).
    branches = Set{Vector{Symbol}}()
    for leaf in leaves
        leaf_address = isempty(leaf) ? :tree : (:tree => foldr(=>, leaf))
        push!(selection, leaf_address => :mean)
        push!(selection, leaf_address => :variance)
        # every proper prefix of the leaf path addresses a branch node
        for i in eachindex(leaf)
            prefix = leaf[1:i]
            if prefix ∉ branches
                push!(branches, prefix)
                branch_address = isempty(prefix) ? :tree : (:tree => foldr(=>, prefix))
                push!(selection, branch_address => :frac)
            end
        end
    end
    return selection
end

# Run n_iter MCMC sweeps on `model` conditioned on the observed image `img`.
# Each sweep performs a split/merge MH move on the tree structure, then MAP
# optimization of the selected continuous parameters. Returns the final trace.
function do_inference(model, img, n_iter)
    # condition on image
    observations = choicemap()
    observations[:img] = [img...]

    # generate initial trace consistent with observed data
    (trace, _) = generate(model, (size(img),), observations)

    # BUG FIX: was bound as `continous_variables` (typo) but used as
    # `continuous_variables` in the loop, causing an UndefVarError.
    # NOTE(review): this selection is built once from the initial tree; after
    # a split/merge changes the structure it may reference stale addresses —
    # consider rebuilding it inside the loop (likely related to the
    # map_optimize errors reported below).
    continuous_variables = select_continuous(trace[:tree])
    # do MCMC
    for iter = 1:n_iter
        # do MH move on the subtree
        trace = replace_split_merge_move(trace)
        # optimize continuous variables
        trace = map_optimize(trace, continuous_variables)
    end
    return trace
end;

but this runs into lots of different bugs like `ERROR: ArgumentError: Converting an instance of ReverseDiff.TrackedReal{Float64,Float64,Nothing} to Float64 is not defined. Please use ReverseDiff.value instead.` and `ERROR: AssertionError: length(arr) >= start_idx`, which makes me think there's something fundamentally wrong with how I'm setting this up.

In any case, thank you for your time!