EnzymeAD / Reactant.jl

MIT License
58 stars 5 forks source link

Infinite recursion when compiling a function #168

Open jofrevalles opened 2 weeks ago

jofrevalles commented 2 weeks ago

CC @mofeing: I have been trying to compile a function that performs a tensor-network contraction for a given set of parameters. However, I got a strange `StackOverflowError`. Context:

using Yao
using Enzyme
using Reactant
using Tenet
using EinExprs
using Adapt
# Two-qubit variational ansatz: two layers of single-qubit Ry rotations, one
# rotation per qubit in each layer. `params` is read at indices 1, 2, 4 and 5;
# index 3 belongs to the controlled-Ry gate, which is currently commented out.
# NOTE(review): each `params[i]` is a scalar index; when `params` is a traced
# Reactant array this is presumably what emits the "scalar indexing" warnings.
function ansatz(params)
    first_layer = (put(1 => Ry(params[1])), put(2 => Ry(params[2])))
    # control(1, 2=>Ry(params[3])),
    second_layer = (put(1 => Ry(params[4])), put(2 => Ry(params[5])))
    return chain(2, first_layer..., second_layer...)
end

# Sanity check: build the ansatz once with 5 random angles (`ansatz` indexes
# `params` up to 5). `ansatz_circ` is not referenced elsewhere in this snippet.
ansatz_circ = ansatz(rand(5))

# Build the (uncontracted) Tenet network for the expectation value of an
# observable in the state |ψ⟩ = ansatz(params)|00⟩.
#
# `params` is forwarded unchanged to `ansatz`. The returned network merges the
# prepared state, the observable and the state's adjoint (`state'`); contracting
# it is left to the caller.
function expectation(params)
    # Observable: Z on qubit 1 chained with Z on qubit 2.
    # NOTE(review): Yao's `chain` composes blocks sequentially, so this is the
    # product Z₁Z₂ rather than the sum Z₁ + Z₂ — confirm this is intended.
    observable_block = chain(2,
        put(1=>Z),
        put(2=>Z),
        # control(1, 2=>X)
    )
    circuit = ansatz(params)

    # Keep separate names for the Yao blocks and their Tenet conversions instead
    # of rebinding one variable to a value of a different type.
    qcirc = Tenet.Quantum(circuit)
    H = Tenet.Quantum(observable_block)

    # |00⟩ as a product state. The comprehension gives each qubit its own vector:
    # `fill([1, 0], 2)` would store the *same* array object twice, so an in-place
    # change to one site's data would silently change the other as well.
    psi000 = Tenet.Quantum(Tenet.Product([[1, 0] for _ in 1:2])) # |00>
    state = merge(psi000, qcirc) # circuit|00>

    # Presumably sandwiches the observable between ⟨ψ| and |ψ⟩ — confirm against
    # Tenet's `merge` semantics.
    expval = merge(state, H, state')
    return expval
end

# Build the network with plain random `Vector{Float64}` parameters, only to
# precompute a contraction path outside of the compiled function.
# `qtn` is not referenced elsewhere in this snippet.
qtn = expectation(rand(5))
# NOTE(review): this builds a second, independent network from fresh random
# parameters instead of converting `qtn`.
tn = Tenet.TensorNetwork(expectation(rand(5)))
# Greedy contraction-order search. `loss` reads this global `path`, so the
# parameter values used here presumably only need to yield a network with the
# same index structure — TODO confirm.
path = einexpr(tn; optimizer=Greedy())

# Objective passed to `@compile`: builds the expectation network for `params`
# and contracts it using the contraction order stored in the global `path`.
# That path was computed from a network built with other (random) parameters,
# which assumes the network structure does not depend on the values — TODO confirm.
function loss(params)
    network = expectation(params) |> Tenet.TensorNetwork
    # Searching for the path inside this function is disabled; the global
    # `path` is used instead.
    # path = einexpr(network; optimizer=Greedy())
    return Tenet.contract(network; path = path)
end

# Example input: 5 random angles, converted with Adapt into a Reactant
# `ConcreteRArray`. This is the argument passed to `@compile loss(params′)`.
params = rand(5)
params′ = adapt(ConcreteRArray, params)

Then:


julia> f = @compile loss(params′)
┌ Warning: Performing scalar indexing on task Task (runnable) @0x000071040fa00010.
│ Invocation resulted in scalar indexing of a TracedRArray.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on device, but very slowly on the CPU,
│ and require expensive copies and synchronization each time and therefore should be avoided.
└ @ Reactant ~/.julia/packages/Reactant/9aeyF/src/TracedRArray.jl:53
┌ Warning: Performing scalar indexing on task Task (runnable) @0x000071040fa00010.
│ Invocation resulted in scalar indexing of a TracedRArray.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on device, but very slowly on the CPU,
│ and require expensive copies and synchronization each time and therefore should be avoided.
└ @ Reactant ~/.julia/packages/Reactant/9aeyF/src/TracedRArray.jl:53
┌ Warning: Performing scalar indexing on task Task (runnable) @0x000071040fa00010.
│ Invocation resulted in scalar indexing of a TracedRArray.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on device, but very slowly on the CPU,
│ and require expensive copies and synchronization each time and therefore should be avoided.
└ @ Reactant ~/.julia/packages/Reactant/9aeyF/src/TracedRArray.jl:53
┌ Warning: Performing scalar indexing on task Task (runnable) @0x000071040fa00010.
│ Invocation resulted in scalar indexing of a TracedRArray.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on device, but very slowly on the CPU,
│ and require expensive copies and synchronization each time and therefore should be avoided.
└ @ Reactant ~/.julia/packages/Reactant/9aeyF/src/TracedRArray.jl:53
┌ Warning: Performing scalar indexing on task Task (runnable) @0x000071040fa00010.
│ Invocation resulted in scalar indexing of a TracedRArray.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on device, but very slowly on the CPU,
│ and require expensive copies and synchronization each time and therefore should be avoided.
└ @ Reactant ~/.julia/packages/Reactant/9aeyF/src/TracedRArray.jl:53
ERROR: StackOverflowError:
Stacktrace:
     [1] mlirOperationCreate
       @ ~/.julia/packages/Reactant/9aeyF/src/mlir/libMLIR_h.jl:985 [inlined]
     [2] create_operation(name::String, loc::Reactant.MLIR.IR.Location; results::Vector{…}, operands::Vector{…}, owned_regions::Vector{…}, successors::Vector{…}, attributes::Vector{…}, result_inference::Bool)
       @ Reactant.MLIR.IR ~/.julia/packages/Reactant/9aeyF/src/mlir/IR/Operation.jl:315
     [3] create_operation
       @ ~/.julia/packages/Reactant/9aeyF/src/mlir/IR/Operation.jl:273 [inlined]
     [4] broadcast_in_dim(operand::Reactant.MLIR.IR.Value; result_0::Reactant.MLIR.IR.Type, broadcast_dimensions::Reactant.MLIR.IR.Attribute, location::Reactant.MLIR.IR.Location)
       @ Reactant.MLIR.Dialects.stablehlo ~/.julia/artifacts/6c9478736605d4915b52b83803cbe351dc32bd48/StableHLO.inc.jl:430
     [5] broadcast_in_dim
       @ ~/.julia/artifacts/6c9478736605d4915b52b83803cbe351dc32bd48/StableHLO.inc.jl:423 [inlined]
     [6] broadcast_to_size_internal(x::Reactant.TracedRArray{ComplexF64, 0}, rsize::Base.Generator{Base.Iterators.Zip{Tuple{Tuple{}, Tuple{}}}, Reactant.var"#96#97"})
       @ Reactant ~/.julia/packages/Reactant/9aeyF/src/TracedRArray.jl:737
     [7] broadcast_to_size
       @ ~/.julia/packages/Reactant/9aeyF/src/TracedRArray.jl:714 [inlined]
     [8] broadcast_to_size
       @ ~/.julia/packages/Reactant/9aeyF/src/TracedRArray.jl:719 [inlined]
     [9] #98
       @ ./none:0 [inlined]
    [10] iterate
       @ ./generator.jl:47 [inlined]
    [11] _copyto!(dest::Reactant.TracedRArray{ComplexF64, 0}, bc::Base.Broadcast.Broadcasted{Nothing, Nothing, typeof(conj), Tuple{Reactant.TracedRArray{ComplexF64, 0}}})
       @ Reactant ~/.julia/packages/Reactant/9aeyF/src/TracedRArray.jl:759
    [12] copyto!(dest::Reactant.TracedRArray{ComplexF64, 0}, bc::Base.Broadcast.Broadcasted{Nothing, Nothing, typeof(conj), Tuple{Reactant.TracedRArray{ComplexF64, 0}}})
       @ Reactant ~/.julia/packages/Reactant/9aeyF/src/TracedRArray.jl:681
    [13] copyto!(dest::Reactant.TracedRArray{ComplexF64, 0}, bc::Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{0}, Nothing, typeof(conj), Tuple{Reactant.TracedRArray{ComplexF64, 0}}})
       @ Base.Broadcast ./broadcast.jl:967
    [14] copy(bc::Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{0}, Nothing, typeof(conj), Tuple{Reactant.TracedRArray{ComplexF64, 0}}})
       @ Reactant ~/.julia/packages/Reactant/9aeyF/src/TracedRArray.jl:641
    [15] materialize
       @ ./broadcast.jl:903 [inlined]
    [16] broadcast_preserving_zero_d(f::Function, As::Reactant.TracedRArray{ComplexF64, 0})
       @ Base.Broadcast ./broadcast.jl:892
    [17] conj(A::Reactant.TracedRArray{ComplexF64, 0})
       @ Base ./abstractarraymath.jl:145
    [18] elem_apply(f::Function, args::Reactant.TracedRArray{ComplexF64, 0})
       @ Reactant ~/.julia/packages/Reactant/9aeyF/src/TracedRArray.jl:354
--- the last 8 lines are repeated 6663 more times ---
 [53323] _copyto!(dest::Reactant.TracedRArray{ComplexF64, 0}, bc::Base.Broadcast.Broadcasted{Nothing, Nothing, typeof(conj), Tuple{Reactant.TracedRArray{ComplexF64, 0}}})
       @ Reactant ~/.julia/packages/Reactant/9aeyF/src/TracedRArray.jl:759
 [53324] copyto!(dest::Reactant.TracedRArray{ComplexF64, 0}, bc::Base.Broadcast.Broadcasted{Nothing, Nothing, typeof(conj), Tuple{Reactant.TracedRArray{ComplexF64, 0}}})
       @ Reactant ~/.julia/packages/Reactant/9aeyF/src/TracedRArray.jl:681
 [53325] copyto!(dest::Reactant.TracedRArray{ComplexF64, 0}, bc::Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{0}, Nothing, typeof(conj), Tuple{Reactant.TracedRArray{ComplexF64, 0}}})
       @ Base.Broadcast ./broadcast.jl:967
 [53326] copy(bc::Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{0}, Nothing, typeof(conj), Tuple{Reactant.TracedRArray{ComplexF64, 0}}})
       @ Reactant ~/.julia/packages/Reactant/9aeyF/src/TracedRArray.jl:641
 [53327] materialize
       @ ./broadcast.jl:903 [inlined]
 [53328] broadcast_preserving_zero_d(f::Function, As::Reactant.TracedRArray{ComplexF64, 0})
       @ Base.Broadcast ./broadcast.jl:892
Some type information was truncated. Use `show(err)` to see complete types.

Do you know what may have caused this?

Thank you!

mofeing commented 2 weeks ago

I can't locate any line from Tenet or YaoBlocks in that stacktrace. Can you post the full, untruncated stacktrace by running `show(err)`?

mofeing commented 2 weeks ago

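The error seems to be that we haven't overloaded `conj` for `TracedRArray`. Reactant's `elem_apply` calls `conj` on the whole (0-dimensional) array. `Base.conj` on an array broadcasts back into Reactant's `_copyto!`, which calls `elem_apply` again, so the two recurse until the stack overflows.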