TensorBFS / TensorInference.jl

Probabilistic inference using contraction of tensor networks
https://tensorbfs.github.io/TensorInference.jl/
MIT License

different behavior of `argmax` for Vector and Matrix leads to error #91

Closed ArrogantGao closed 8 months ago

ArrogantGao commented 8 months ago

In the function

function most_probable_config(tn::TensorNetworkModel; usecuda = false)::Tuple{Real, Vector}
    expected_mars = [[l] for l in get_vars(tn)]
    @assert tn.mars[1:length(expected_mars)] == expected_mars "To get the most probable configuration, the leading elements of `tn.vars` must be `$expected_mars`"
    vars = get_vars(tn)
    tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false))
    logp, grads = cost_and_gradient(tn.code, tensors)
    # use Array to convert CuArray to CPU arrays
    return content(Array(logp)[]), map(k -> haskey(tn.evidence, vars[k]) ? tn.evidence[vars[k]] : argmax(grads[k]) - 1, 1:length(vars))
end

the function `argmax` is used, and its behavior differs between `Vector` and `Matrix` inputs:

julia> argmax(randn(4))
2

julia> argmax(randn(4, 4))
CartesianIndex(4, 2)

so `argmax(grads[k]) - 1` raises an error when `grads[k]` is a `Matrix`, because an integer cannot be subtracted from a `CartesianIndex`.
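A minimal sketch of the discrepancy, together with a shape-agnostic workaround (the `safe_argmax` helper is hypothetical, not the fix adopted in the package): flattening with `vec` first makes `argmax` always return a linear `Int` index.

```julia
# argmax on a Vector returns an Int; on a Matrix it returns a CartesianIndex,
# so `argmax(A) - 1` only type-checks in the Vector case.
v = [0.1, 0.9, 0.3]
M = [0.1 0.9; 0.3 0.2]

argmax(v)  # 2 (an Int)
argmax(M)  # CartesianIndex(1, 2)

# Hypothetical workaround: flatten first, so the result is always an Int.
safe_argmax(A) = argmax(vec(A))

safe_argmax(v) - 1  # 1
safe_argmax(M) - 1  # 2 (linear, column-major index of the maximum)
```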

ArrogantGao commented 8 months ago

Sorry, it was my mistake: grads[k] should not be a matrix. When creating the tensor network for MPE, there must be a rank-1 all-ones vector for each variable, for example:

julia> factors
10-element Vector{TensorInference.Factor{Float64, 2}}:
 TensorInference.Factor{Float64, 2}((1, 2), [0.1853742993051539 0.24266609993460198; 0.3139558820831363 0.258003718677108])
 TensorInference.Factor{Float64, 2}((1, 3), [0.19490302071525448 0.2331373785245014; 0.37083643432205143 0.2011231664381927])
 TensorInference.Factor{Float64, 2}((1, 4), [0.19159768973795085 0.236442709501805; 0.2699053176129627 0.30205428314728144])
 TensorInference.Factor{Float64, 2}((1, 5), [0.16153413648476111 0.2665062627549947; 0.29757932897291595 0.2743802717873282])
 TensorInference.Factor{Float64, 2}((2, 3), [0.31429995086808404 0.18503023052020612; 0.25143950416922195 0.24923031444248797])
 TensorInference.Factor{Float64, 2}((2, 4), [0.23558618140615134 0.2637439999821388; 0.22591682594476228 0.27475299266694764])
 TensorInference.Factor{Float64, 2}((2, 5), [0.2561260234949363 0.24320415789335384; 0.2029874419627408 0.29768237664896907])
 TensorInference.Factor{Float64, 2}((3, 4), [0.2893861660460715 0.2763532889912345; 0.17211684130484214 0.2621437036578519])
 TensorInference.Factor{Float64, 2}((3, 5), [0.27993586794877484 0.28580358708853115; 0.17917759750890222 0.2550829474537918])
 TensorInference.Factor{Float64, 2}((4, 5), [0.24961217921386344 0.21189082813705018; 0.20950128624381362 0.32899570640527276])

julia> tn.tensors
15-element Vector{Array{Float64}}:
 [1.0, 1.0]
 [1.0, 1.0]
 [1.0, 1.0]
 [1.0, 1.0]
 [1.0, 1.0]
 [0.1853742993051539 0.24266609993460198; 0.3139558820831363 0.258003718677108]
 [0.19490302071525448 0.2331373785245014; 0.37083643432205143 0.2011231664381927]
 [0.19159768973795085 0.236442709501805; 0.2699053176129627 0.30205428314728144]
 [0.16153413648476111 0.2665062627549947; 0.29757932897291595 0.2743802717873282]
 [0.31429995086808404 0.18503023052020612; 0.25143950416922195 0.24923031444248797]
 [0.23558618140615134 0.2637439999821388; 0.22591682594476228 0.27475299266694764]
 [0.2561260234949363 0.24320415789335384; 0.2029874419627408 0.29768237664896907]
 [0.2893861660460715 0.2763532889912345; 0.17211684130484214 0.2621437036578519]
 [0.27993586794877484 0.28580358708853115; 0.17917759750890222 0.2550829474537918]
 [0.24961217921386344 0.21189082813705018; 0.20950128624381362 0.32899570640527276]
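A small sketch (with hypothetical gradient values, not taken from the model above) of why those per-variable vectors matter: the gradient of the contraction with respect to a rank-1 tensor is itself a vector, so `argmax(grad) - 1` yields a 0-based state as intended. If a variable appeared only inside rank-2 factors, the corresponding gradient would be a matrix and `argmax` would return a `CartesianIndex`.

```julia
# Gradient w.r.t. a per-variable [1.0, 1.0] vector tensor: a Vector,
# so subtracting 1 gives the 0-based configuration of that variable.
grad_vec = [0.3, 0.7]
config = argmax(grad_vec) - 1  # 1

# Gradient w.r.t. a rank-2 factor: a Matrix, where argmax returns a
# CartesianIndex and `argmax(...) - 1` would throw a MethodError.
grad_mat = [0.3 0.7; 0.1 0.2]
argmax(grad_mat)  # CartesianIndex(1, 2)
```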