ArrogantGao / TensorBranching.jl

Optimal branching via tensor network
MIT License
4 stars 1 forks source link

Usage of `BitStr` leads to huge recompiling. #11

Open ArrogantGao opened 3 months ago

ArrogantGao commented 3 months ago
julia> using TensorBranching

julia> g = random_regular_graph(120, 3);

julia> TensorBranching.mis_solver(g, SetCoverBranching(), D3Measure(), MinBoundSelector(2), EnvFilter())

julia> @time TensorBranching.mis_solver(g, SetCoverBranching(), D3Measure(), MinBoundSelector(2), EnvFilter())
  0.856605 seconds (6.00 M allocations: 487.990 MiB, 4.02% gc time, 47.59% compilation time)
CountingMIS(53, 125)

julia> @code_warntype TensorBranching.mis_solver(g, SetCoverBranching(), D3Measure(), MinBoundSelector(2), EnvFilter())
MethodInstance for TensorBranching.mis_solver(::SimpleGraph{Int64}, ::SetCoverBranching, ::D3Measure, ::MinBoundSelector, ::EnvFilter)
  from mis_solver(g::SimpleGraph, strategy::AbsractBranching, measurement::AbstractMeasurement, vertex_select::AbstractVertexSelector, filter::AbstractTruthFilter) @ TensorBranching ~/code/TensorBranching.jl/src/solver.jl:6
Arguments
  #self#::Core.Const(TensorBranching.mis_solver)
  g::SimpleGraph{Int64}
  strategy::SetCoverBranching
  measurement::Core.Const(D3Measure())
  vertex_select::MinBoundSelector
  filter::Core.Const(EnvFilter())
Locals
  @_7::Union{Nothing, Tuple{Int64, Int64}}
  @_8::Int64
  #82::TensorBranching.var"#82#84"
  #81::TensorBranching.var"#81#83"
  max_mis::CountingMIS
  mis_count::Vector{CountingMIS}
  dnf::DNF{_A, Int64} where _A
  openvertices::Vector{Int64}
  vertices::Vector{Int64}
  v::Union{Nothing, Int64}
  dg::Vector{Int64}
  i::Int64
  gi::SimpleGraph{Int64}
  rvs::Vector{Int64}
  clause::Clause{_A, Int64} where _A
Body::CountingMIS
1 ──        Core.NewvarNode(:(@_7))
│           Core.NewvarNode(:(@_8))
│           Core.NewvarNode(:(#82))
│           Core.NewvarNode(:(#81))
│           Core.NewvarNode(:(max_mis))
│           Core.NewvarNode(:(mis_count))
│           Core.NewvarNode(:(dnf))
│           Core.NewvarNode(:(openvertices))
│           Core.NewvarNode(:(vertices))
│           Core.NewvarNode(:(v))
│           (dg = TensorBranching.degree(g))
│    %12  = TensorBranching.nv(g)::Int64
│    %13  = (%12 == 0)::Bool
└───        goto #3 if not %13
2 ──        goto #4
3 ── %16  = TensorBranching.nv(g)::Int64
│    %17  = (%16 == 1)::Bool
└───        goto #5 if not %17
4 ┄─ %19  = TensorBranching.nv(g)::Int64
│    %20  = TensorBranching.CountingMIS(%19)::Core.PartialStruct(CountingMIS, Any[Int64, Core.Const(1)])
└───        return %20
5 ── %22  = (0 ∈ dg)::Bool
└───        goto #7 if not %22
6 ──        goto #8
7 ── %25  = (1 ∈ dg)::Bool
└───        goto #9 if not %25
8 ┄─        (#81 = %new(TensorBranching.:(var"#81#83")))
│    %28  = #81::Core.Const(TensorBranching.var"#81#83"())
│           (v = TensorBranching.findfirst(%28, dg))
│    %30  = TensorBranching.nv(g)::Int64
│    %31  = (1:%30)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│    %32  = v::Union{Nothing, Int64}
│    %33  = TensorBranching.neighbors(g, v)::Vector{Int64}
│    %34  = (%32 ∪ %33)::Vector{Int64}
│    %35  = TensorBranching.setdiff(%31, %34)::Vector{Int64}
│    %36  = TensorBranching.induced_subgraph(g, %35)::Tuple{SimpleGraph{Int64}, Vector{Int64}}
│    %37  = Base.getindex(%36, 1)::SimpleGraph{Int64}
│    %38  = TensorBranching.mis_solver(%37, strategy, measurement, vertex_select, filter)::CountingMIS
│    %39  = (1 + %38)::CountingMIS
└───        return %39
9 ── %41  = (2 ∈ dg)::Bool
└───        goto #11 if not %41
10 ─        (#82 = %new(TensorBranching.:(var"#82#84")))
│    %44  = #82::Core.Const(TensorBranching.var"#82#84"())
│           (v = TensorBranching.findfirst(%44, dg))
│    %46  = TensorBranching.folding(g, v, strategy, measurement, vertex_select, filter)::CountingMIS
└───        return %46
11 ─ %48  = TensorBranching.maximum(dg)::Int64
│    %49  = (%48 ≥ 6)::Bool
└───        goto #13 if not %49
12 ─        (v = TensorBranching.argmax(dg))
│    %52  = TensorBranching.nv(g)::Int64
│    %53  = (1:%52)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│    %54  = v::Int64
│    %55  = TensorBranching.neighbors(g, v::Int64)::Vector{Int64}
│    %56  = (%54 ∪ %55)::Vector{Int64}
│    %57  = TensorBranching.setdiff(%53, %56)::Vector{Int64}
│    %58  = TensorBranching.induced_subgraph(g, %57)::Tuple{SimpleGraph{Int64}, Vector{Int64}}
│    %59  = Base.getindex(%58, 1)::SimpleGraph{Int64}
│    %60  = TensorBranching.mis_solver(%59, strategy, measurement, vertex_select, filter)::CountingMIS
│    %61  = (1 + %60)::CountingMIS
│    %62  = TensorBranching.nv(g)::Int64
│    %63  = (1:%62)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│    %64  = TensorBranching.setdiff(%63, v::Int64)::Vector{Int64}
│    %65  = TensorBranching.induced_subgraph(g, %64)::Tuple{SimpleGraph{Int64}, Vector{Int64}}
│    %66  = Base.getindex(%65, 1)::SimpleGraph{Int64}
│    %67  = TensorBranching.mis_solver(%66, strategy, measurement, vertex_select, filter)::CountingMIS
│    %68  = TensorBranching.max(%61, %67)::CountingMIS
└───        return %68
13 ─ %70  = TensorBranching.optimal_branching_dnf(g, strategy, measurement, vertex_select, filter)::Tuple{Vector{Int64}, Vector{Int64}, DNF{_A, Int64} where _A}
│    %71  = Base.indexed_iterate(%70, 1)::Core.PartialStruct(Tuple{Vector{Int64}, Int64}, Any[Vector{Int64}, Core.Const(2)])
│           (vertices = Core.getfield(%71, 1))
│           (@_8 = Core.getfield(%71, 2))
│    %74  = Base.indexed_iterate(%70, 2, @_8::Core.Const(2))::Core.PartialStruct(Tuple{Vector{Int64}, Int64}, Any[Vector{Int64}, Core.Const(3)])
│           (openvertices = Core.getfield(%74, 1))
│           (@_8 = Core.getfield(%74, 2))
│    %77  = Base.indexed_iterate(%70, 3, @_8::Core.Const(3))::Core.PartialStruct(Tuple{DNF{_A, Int64} where _A, Int64}, Any[DNF{_A, Int64} where _A, Core.Const(4)])
│           (dnf = Core.getfield(%77, 1))
│    %79  = Core.apply_type(TensorBranching.Vector, TensorBranching.CountingMIS)::Core.Const(Vector{CountingMIS})
│    %80  = TensorBranching.undef::Core.Const(UndefInitializer())
│    %81  = Base.getproperty(dnf, :clauses)::Array{Clause{_A, Int64}, 1} where _A
│    %82  = TensorBranching.length(%81)::Int64
│           (mis_count = (%79)(%80, %82))
│    %84  = Base.getproperty(dnf, :clauses)::Array{Clause{_A, Int64}, 1} where _A
│    %85  = TensorBranching.length(%84)::Int64
│    %86  = (1:%85)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│           (@_7 = Base.iterate(%86))
│    %88  = (@_7 === nothing)::Bool
│    %89  = Base.not_int(%88)::Bool
└───        goto #16 if not %89
14 ┄ %91  = @_7::Tuple{Int64, Int64}
│           (i = Core.getfield(%91, 1))
│    %93  = Core.getfield(%91, 2)::Int64
│    %94  = Base.getproperty(dnf, :clauses)::Array{Clause{_A, Int64}, 1} where _A
│           (clause = Base.getindex(%94, i))
│           (rvs = TensorBranching.removed_vertices(vertices, g, clause))
│           (gi = TensorBranching.copy(g))
│           TensorBranching.rem_vertices!(gi, rvs)
│    %99  = TensorBranching.mis_solver(gi, strategy, measurement, vertex_select, filter)::CountingMIS
│    %100 = Base.getproperty(clause, :val)::DitStr{2, _A, Int64} where _A
│    %101 = TensorBranching.count_ones(%100)::Int64
│    %102 = (%99 + %101)::CountingMIS
│           Base.setindex!(mis_count, %102, i)
│           (@_7 = Base.iterate(%86, %93))
│    %105 = (@_7 === nothing)::Bool
│    %106 = Base.not_int(%105)::Bool
└───        goto #16 if not %106
15 ─        goto #14
16 ┄        (max_mis = TensorBranching.maximum(mis_count))
└───        return max_mis
ArrogantGao commented 3 months ago

The highlighting here is not correct: in the REPL, `Array{Clause{_A, Int64}, 1} where _A` is labeled in red. The length of the `BitStr`, `_A`, is always changing, which leads to huge recompilation. Can this be solved, or should we just ignore it? @GiggleLiu

GiggleLiu commented 3 months ago

I think it can be resolved by using the LongLongUInt type in BitBasis directly. It has a type parameter C to specify how many UInt64 are required to store a long bitstring.

GiggleLiu commented 3 months ago

Another main source of compilation time is the MIS solver based on generic tensor networks. We should consider using a brute-force solver.

ArrogantGao commented 3 months ago

> I think it can be resolved by using the LongLongUInt type in BitBasis directly. It has a type parameter C to specify how many UInt64 are required to store a long bitstring.

Here `N` is the length of the bit string, corresponding to the number of vertices selected for calculating the reduced alpha tensor, so it cannot be fixed.