EnzymeAD / Reactant.jl

MIT License
57 stars 4 forks source link

`XlaRuntimeError` when passing a `Complex` buffer to a compiled function #172

Open mofeing opened 1 week ago

mofeing commented 1 week ago

The MLIR code appears to be generated correctly (see the example below). However, passing buffers of complex numbers to the compiled function crashes with `XlaRuntimeError: Executable expected parameter 0 of size 64 but got buffer with incompatible size 4`.

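Could the `primitive_type` mapping for `Complex` be wrong? The expected size of 64 bytes matches a 2×2 `complex<f64>` tensor (4 elements × 16 bytes each). The buffer actually passed is only 4 bytes, which suggests the element type of the buffer is not being translated to the matching XLA complex type.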

julia> using Tenet, Yao, Reactant

julia> @code_hlo Tenet.contract(observable′)
Module:
module attributes {transform.with_named_sequence} {
  func.func @main(%arg0: tensor<2x2xf64>, %arg1: tensor<2x2xf64>) -> tensor<2x2x2x2xf64> {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf64>) -> tensor<2x2xf64>
    %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<2x2xf64>) -> tensor<2x2xf64>
    %2 = stablehlo.einsum %0, %1, config = "AB,CD->ABCD" : (tensor<2x2xf64>, tensor<2x2xf64>) -> tensor<2x2x2x2xf64>
    %3 = stablehlo.transpose %2, dims = [3, 2, 1, 0] : (tensor<2x2x2x2xf64>) -> tensor<2x2x2x2xf64>
    return %3 : tensor<2x2x2x2xf64>
  }
}

julia> f = @compile Tenet.contract(observable′)
Reactant.Compiler.Thunk{Symbol("##contract_reactant#225")}()

julia> f(observable′)
2×2×2×2 Tensor{Float64, 4, ConcreteRArray{Float64, 4}}:
[:, :, 1, 1] =
 1.0   0.0
 0.0  -1.0

[:, :, 2, 1] =
 0.0   0.0
 0.0  -0.0

[:, :, 1, 2] =
 0.0   0.0
 0.0  -0.0

[:, :, 2, 2] =
 -1.0  -0.0
 -0.0   1.0

julia> observable = Quantum(chain(2, put(1 => Z), put(2 => Z)))
Quantum (inputs=2, outputs=2)

julia> for tensor in tensors(observable)
           replace!(observable, tensor => Tensor(collect(parent(tensor)), inds(tensor)))
           # replace!(observable, tensor => Tensor(real(collect(parent(tensor))), inds(tensor)))
       end

julia> observable′ = adapt(ConcreteRArray, observable)
Quantum (inputs=2, outputs=2)

julia> @code_hlo Tenet.contract(observable′)
Module:
module attributes {transform.with_named_sequence} {
  func.func @main(%arg0: tensor<2x2xcomplex<f64>>, %arg1: tensor<2x2xcomplex<f64>>) -> tensor<2x2x2x2xcomplex<f64>> {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xcomplex<f64>>) -> tensor<2x2xcomplex<f64>>
    %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<2x2xcomplex<f64>>) -> tensor<2x2xcomplex<f64>>
    %2 = stablehlo.einsum %0, %1, config = "AB,CD->ABCD" : (tensor<2x2xcomplex<f64>>, tensor<2x2xcomplex<f64>>) -> tensor<2x2x2x2xcomplex<f64>>
    %3 = stablehlo.transpose %2, dims = [3, 2, 1, 0] : (tensor<2x2x2x2xcomplex<f64>>) -> tensor<2x2x2x2xcomplex<f64>>
    return %3 : tensor<2x2x2x2xcomplex<f64>>
  }
}

julia> f = @compile Tenet.contract(observable′)
Reactant.Compiler.Thunk{Symbol("##contract_reactant#226")}()

julia> f(observable′)
libc++abi: terminating due to uncaught exception of type xla::XlaRuntimeError: INVALID_ARGUMENT: Executable expected parameter 0 of size 64 but got buffer with incompatible size 4

[22501] signal (6): Abort trap: 6
in expression starting at REPL[21]:1
__pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line)
Allocations: 58639432 (Pool: 58535606; Big: 103826); GC: 46
fish: Job 1, 'julia' terminated by signal SIGABRT (Abort)
mofeing commented 1 week ago

CC @jofrevalles