EnzymeAD / Reactant.jl


Reactant fails to compile a function accepting complex parameters #235

Closed · Todorbsc closed this 3 weeks ago

Todorbsc commented 4 weeks ago

CC @mofeing

In the following simple example, I tried to compile (@code_hlo) a loss function that computes the element-wise distance between two arrays, params and expected. In the first case, compilation fails with an error when the element type is complex; in the second case, there are no errors when the arrays' element type is real.

julia> using Enzyme

julia> using Reactant
AssertionError("Could not find registered platform with name: \"cuda\". Available platform names are: ")

julia> using Adapt

julia> N = 10
10

julia> params = rand(ComplexF64, N)
10-element Vector{ComplexF64}:
 0.12084727623502733 + 0.8378834373925023im
  0.5419946891541254 + 0.6053052086481233im
  0.2565473286949522 + 0.8431469031477166im
    0.70662642814057 + 0.6980685781138029im
  0.9987939974878999 + 0.17192767005946274im
  0.4165299259757198 + 0.25379805337185957im
 0.11717742496634787 + 0.265571005045361im
  0.7710465192286735 + 0.8409264098078348im
  0.9924029253223047 + 0.43294561767846096im
  0.6572012110110073 + 0.9084338310312545im

julia> expected = rand(ComplexF64, N)
10-element Vector{ComplexF64}:
 0.9447681271694002 + 0.4096624268529826im
 0.5716694247178182 + 0.4528185471977071im
 0.6780196017731785 + 0.9039003321208116im
 0.5023224456018169 + 0.7798185220036048im
 0.5792300530436798 + 0.387493494572811im
 0.3068605176293002 + 0.7912510612324968im
 0.7809533996226177 + 0.05725168150432591im
 0.9422271928849323 + 0.811064536254466im
 0.7181801909588854 + 0.42381843403608377im
 0.8468104487119836 + 0.8343389380025642im

julia> function f(params, expected)
           return sum(abs.(expected - params))
       end
f (generic function with 1 method)

julia> params′ = adapt(ConcreteRArray, params)
10-element ConcreteRArray{ComplexF64, 1}:
 0.12084727623502733 + 0.8378834373925023im
  0.5419946891541254 + 0.6053052086481233im
  0.2565473286949522 + 0.8431469031477166im
    0.70662642814057 + 0.6980685781138029im
  0.9987939974878999 + 0.17192767005946274im
  0.4165299259757198 + 0.25379805337185957im
 0.11717742496634787 + 0.265571005045361im
  0.7710465192286735 + 0.8409264098078348im
  0.9924029253223047 + 0.43294561767846096im
  0.6572012110110073 + 0.9084338310312545im

julia> expected′ = adapt(ConcreteRArray, expected)
10-element ConcreteRArray{ComplexF64, 1}:
 0.9447681271694002 + 0.4096624268529826im
 0.5716694247178182 + 0.4528185471977071im
 0.6780196017731785 + 0.9039003321208116im
 0.5023224456018169 + 0.7798185220036048im
 0.5792300530436798 + 0.387493494572811im
 0.3068605176293002 + 0.7912510612324968im
 0.7809533996226177 + 0.05725168150432591im
 0.9422271928849323 + 0.811064536254466im
 0.7181801909588854 + 0.42381843403608377im
 0.8468104487119836 + 0.8343389380025642im

julia> @code_hlo f(params′, expected′)
error: type of return operand 0 ('tensor<f64>') doesn't match function result type ('tensor<complex<f64>>') in function @abs_broadcast_scalar
ERROR: "failed to run pass manager on module"
Stacktrace:
 [1] run!
   @ ~/.julia/packages/Reactant/e7PeE/src/mlir/IR/Pass.jl:70 [inlined]
 [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String)
   @ Reactant.Compiler ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:251
 [3] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Vector{ConcreteRArray{ComplexF64, 1}}; optimize::Bool)
   @ Reactant.Compiler ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:283
 [4] compile_mlir!
   @ ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:265 [inlined]
 [5] #2
   @ ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:260 [inlined]
 [6] context!(f::Reactant.Compiler.var"#2#3"{@Kwargs{optimize::Bool}, typeof(f), Vector{ConcreteRArray{ComplexF64, 1}}}, ctx::Reactant.MLIR.IR.Context)
   @ Reactant.MLIR.IR ~/.julia/packages/Reactant/e7PeE/src/mlir/IR/Context.jl:76
 [7] compile_mlir(f::Function, args::Vector{ConcreteRArray{ComplexF64, 1}}; kwargs::@Kwargs{optimize::Bool})
   @ Reactant.Compiler ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:258
 [8] top-level scope
   @ ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:419
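
As a side note, a possible workaround sketch until this is fixed: compute the modulus explicitly from the real and imaginary parts, so only real-valued ops are traced. This assumes real, imag, and sqrt broadcast correctly over traced arrays, which I have not verified:

julia> function f_workaround(params, expected)
           d = expected - params
           # abs(z) == sqrt(real(z)^2 + imag(z)^2), avoiding complex abs
           return sum(sqrt.(real.(d) .^ 2 .+ imag.(d) .^ 2))
       end
f_workaround (generic function with 1 method)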

julia> params_real = rand(N)
10-element Vector{Float64}:
 0.7157500469767611
 0.73771994445251
 0.09715354680648502
 0.17490931236530827
 0.02665553154828393
 0.7123407609231756
 0.440286278633386
 0.3922503310867863
 0.5138483522747382
 0.26198736321505955

julia> expected_real = rand(N)
10-element Vector{Float64}:
 0.1338721367036817
 0.8848012771114079
 0.22196944824020126
 0.5278528100723199
 0.0774277299797459
 0.7046647715949552
 0.6004498869799781
 0.9856254783292705
 0.8937784894687109
 0.1956969406579151

julia> params_real′ = adapt(ConcreteRArray, params_real)
10-element ConcreteRArray{Float64, 1}:
 0.7157500469767611
 0.73771994445251
 0.09715354680648502
 0.17490931236530827
 0.02665553154828393
 0.7123407609231756
 0.440286278633386
 0.3922503310867863
 0.5138483522747382
 0.26198736321505955

julia> expected_real′ = adapt(ConcreteRArray, expected_real)
10-element ConcreteRArray{Float64, 1}:
 0.1338721367036817
 0.8848012771114079
 0.22196944824020126
 0.5278528100723199
 0.0774277299797459
 0.7046647715949552
 0.6004498869799781
 0.9856254783292705
 0.8937784894687109
 0.1956969406579151

julia> fR = Reactant.@compile f(params_real′, expected_real′)
Reactant.Compiler.Thunk{Symbol("##f_reactant#225")}()

julia> fR(params_real′, expected_real′)
ConcreteRNumber{Float64}(2.464926145172581)

julia> f(params_real, expected_real)
2.464926145172581
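
The failure seems to reduce to broadcasting abs over a complex traced array. A minimal reproducer sketch (same environment as above; I have only run the full example, but I would expect this to hit the same pass-manager error):

julia> x = adapt(ConcreteRArray, rand(ComplexF64, 4));

julia> g(x) = sum(abs.(x))
g (generic function with 1 method)

julia> @code_hlo g(x)  # expected to fail in @abs_broadcast_scalar as above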
mofeing commented 4 weeks ago

The StableHLO spec says that stablehlo.abs should return an f64 tensor in this case, so I don't know why we are inferring a complex<f64> tensor. Maybe our implementation of Base.abs(::TracedRArray) is wrong when we set the MLIR return type?
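
If that is the cause, the fix would be to give the emitted op a real result element type for complex inputs. A minimal sketch of the direction (emit_unary_op and its keyword are placeholders, not Reactant's actual internals):

function Base.abs(x::TracedRArray{T,N}) where {T<:Complex,N}
    # Per the StableHLO spec, abs on complex<f64> produces f64, so the
    # traced result must carry the real element type, not T.
    Tout = real(T)  # e.g. ComplexF64 -> Float64
    # emit_unary_op stands in for Reactant's internal op builder; the key
    # point is that the result MLIR type is built from Tout.
    return emit_unary_op(:abs, x; result_eltype=Tout)
end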