TensorBFS / TensorInference.jl

Probabilistic inference using contraction of tensor networks
https://tensorbfs.github.io/TensorInference.jl/
MIT License

Let marginals return dict #61

Closed GiggleLiu closed 1 year ago

GiggleLiu commented 1 year ago

The updated docstring

Query the marginals of the variables in a TensorNetworkModel. The returned value is a dictionary that maps each queried set of variables to its marginal, where a marginal is a joint probability distribution over the associated variables. By default, the marginals of all individual variables are returned. The variables to query can be specified through the mars field when constructing the TensorNetworkModel. This choice affects the contraction order of the tensor network.

Arguments

Example

The following example is from examples/asia/main.jl.

julia> model = read_model_file(pkgdir(TensorInference, "examples", "asia", "asia.uai"));

julia> tn = TensorNetworkModel(model; evidence=Dict(1=>0))
TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
contraction time = 2^6.022, space = 2^2.0, read-write = 2^7.077

julia> marginals(tn)
Dict{Vector{Int64}, Vector{Float64}} with 8 entries:
  [8] => [0.450138, 0.549863]
  [3] => [0.5, 0.5]
  [1] => [1.0]
  [5] => [0.45, 0.55]
  [4] => [0.055, 0.945]
  [6] => [0.10225, 0.89775]
  [7] => [0.145092, 0.854908]
  [2] => [0.05, 0.95]

julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), mars=[[2, 3], [3, 4]])
TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
contraction time = 2^7.781, space = 2^5.0, read-write = 2^8.443

julia> marginals(tn2)
Dict{Vector{Int64}, Matrix{Float64}} with 2 entries:
  [2, 3] => [0.025 0.025; 0.475 0.475]
  [3, 4] => [0.05 0.45; 0.005 0.495]

In this example, we first set the evidence of variable 1 to 0, then query the marginals of all individual variables. The returned value is a dictionary whose keys are the query variables and whose values are the corresponding marginals. Each marginal is a vector whose entries are the probabilities of the variable taking the values 0 and 1, respectively. For the evidence variable 1, the marginal is always [1.0], since it is fixed to 0.
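The lookup pattern described above can be sketched without loading the package. The snippet below is only an illustration: the dictionary `single` is a hypothetical stand-in, filled by hand with a few entries copied from the `marginals(tn)` output, and has the same key and value types as the printed result.

```julia
# Hypothetical stand-in for the result of `marginals(tn)`;
# the values are copied from the REPL output shown above.
single = Dict{Vector{Int}, Vector{Float64}}(
    [1] => [1.0],            # evidence variable, fixed to 0
    [2] => [0.05, 0.95],
    [4] => [0.055, 0.945],
)

# Keys are vectors of variable indices, so variable 4 is looked up as `[4]`, not `4`.
p4 = single[[4]]

# Julia arrays are 1-based: entry 1 is the probability of value 0, entry 2 of value 1.
println("P(x4 = 0) = ", p4[1])
println("P(x4 = 1) = ", p4[2])

# Every marginal is a probability distribution, so it sums to one (up to rounding).
for (vars, p) in single
    @assert isapprox(sum(p), 1.0; atol=1e-6)
end
```

Indexing with a bare integer such as `single[4]` would throw a `KeyError`, because the keys are of type `Vector{Int64}`.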

Next, we set the marginals to query to the joint distributions over variables 2 and 3, and over variables 3 and 4. Querying joint marginals may or may not increase the contraction time and space. Here, the contraction space complexity increases from 2^2.0 to 2^5.0, and the contraction time complexity increases from 2^6.022 to 2^7.781. The output marginals are joint probabilities of the query variables, represented as tensors.
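As a quick consistency check on the joint marginals, summing out one variable should recover the single-variable marginals printed earlier. The sketch below uses plain Julia arrays with values copied by hand from the `marginals(tn2)` output. The names `joint23` and `joint34` are hypothetical, and the reading that the tensor axes follow the order of the variables in the key is inferred from these numbers, not from a statement in the docstring.

```julia
# Joint marginals copied from the `marginals(tn2)` output above.
joint23 = [0.025 0.025; 0.475 0.475]   # key [2, 3]: rows index variable 2, columns variable 3
joint34 = [0.05 0.45; 0.005 0.495]     # key [3, 4]: rows index variable 3, columns variable 4

# Summing over an axis marginalizes out the variable on that axis.
p2 = vec(sum(joint23; dims=2))   # sum out variable 3
p3 = vec(sum(joint34; dims=2))   # sum out variable 4
p4 = vec(sum(joint34; dims=1))   # sum out variable 3

# These agree with the entries for [2], [3], and [4] in the first query.
@assert isapprox(p2, [0.05, 0.95]; atol=1e-6)
@assert isapprox(p3, [0.5, 0.5]; atol=1e-6)
@assert isapprox(p4, [0.055, 0.945]; atol=1e-6)

println("p2 = ", p2, ", p3 = ", p3, ", p4 = ", p4)
```

On the cost side, the space complexity reported for the joint query is 2^5.0 against 2^2.0 for the default query, a factor of 2^3 = 8.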

codecov[bot] commented 1 year ago

Codecov Report

Merging #61 (6266d7c) into main (a066560) will increase coverage by 3.00%. The diff coverage is 66.66%.

@@            Coverage Diff             @@
##             main      #61      +/-   ##
==========================================
+ Coverage   81.61%   84.61%   +3.00%     
==========================================
  Files          10       10              
  Lines         533      533              
==========================================
+ Hits          435      451      +16     
+ Misses         98       82      -16     
Files Changed    Coverage Δ
src/mar.jl       94.54% <66.66%> (ø)

... and 1 file with indirect coverage changes