tshort / StaticCompiler.jl

Compiles Julia code to a standalone library (experimental)
Other
488 stars 31 forks source link

Automatic zeroing first `N-1` arguments via `*= 0.0`? #152

Open chriselrod opened 4 months ago

chriselrod commented 4 months ago

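I was trying to make a minimal example of a different issue and noticed that the JIT-compiled wrapper always zeroes the first argument, in this case as if by `*= 0.0`: finite values of the first argument are ignored, while `NaN` and `±Inf` produce `NaN` (see the outputs below).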

using LLVM, StaticCompiler

"""
    f(d, s)

Return `-0.5 * abs2(d / s) - sqrt(s)`.
"""
function f(d, s)
  ratio = d / s
  return -0.5 * abs2(ratio) - sqrt(s)
end

# Module produced by StaticCompiler for `f` with argument types `(Float64, Float64)`.
# `compile` below links a copy of this module into the JIT-ed wrapper module.
ccall_mod = StaticCompiler.native_llvm_module(f, (Float64, Float64); demangle=true)

"""
    codegen!(mod::LLVM.Module, name)

Emit into `mod` a function called `name` with signature `double (double, double)`
whose body forwards both parameters to a function named `StaticCompiler.fix_name(f)`
(declared here with external linkage) and returns that call's result.

The module's target triple is set from a host `JITTargetMachine`, the module is
verified, its textual IR is printed to standard output, and `mod` is returned.
"""
function codegen!(mod::LLVM.Module, name)
  param_types = [LLVM.DoubleType(), LLVM.DoubleType()]
  ret_type = LLVM.DoubleType()

  # The target machine is only needed to read the triple string; dispose of it
  # immediately so it is not leaked on every call.
  @dispose tm = LLVM.JITTargetMachine(Sys.MACHINE, Sys.CPU_NAME) begin
    triple!(mod, triple(tm))
  end

  ft = LLVM.FunctionType(ret_type, param_types)
  fn = LLVM.Function(mod, name, ft)

  # generate IR
  @dispose builder = IRBuilder() begin
    entry = BasicBlock(fn, "entry")
    position!(builder, entry)

    # Declare the callee; its definition is expected to be linked in later
    # (see `compile`, which links a copy of `ccall_mod` into the same module).
    ccalltype = LLVM.FunctionType(LLVM.DoubleType(), [LLVM.DoubleType(), LLVM.DoubleType()])
    ccallname = StaticCompiler.fix_name(f)
    fnccall = LLVM.Function(mod, ccallname, ccalltype)
    LLVM.linkage!(fnccall, LLVM.API.LLVMExternalLinkage)

    # Pass the wrapper's two parameters through unchanged, in order.
    params = [parameters(fn)[1], parameters(fn)[2]]
    tmp = call!(builder, ccalltype, fnccall, params)
    ret!(builder, tmp)
  end

  verify(mod)
  println(string(mod))
  return mod
end

"""
    compile()

Build the `"ccall_example"` wrapper with `codegen!`, link a copy of the global
`ccall_mod` into the same module, add it to an `LLJIT` instance, and look up the
wrapper's address.

Returns a closure `(d, s)` that `ccall`s that address with two `Float64` arguments
and a `Float64` return type.
"""
function compile()
  machine = LLVM.JITTargetMachine(Sys.MACHINE, Sys.CPU_NAME)
  jit = LLJIT(; tm=machine)
  @dispose ts_ctx = ThreadSafeContext() begin
    name = "ccall_example"
    ts_mod = LLVM.ThreadSafeModule("ccall_caller")

    # First emit the wrapper, then pull in the definition it calls.
    ts_mod(mod -> codegen!(mod, name))
    ts_mod(mod -> LLVM.link!(mod, copy(ccall_mod)))

    dylib = JITDylib(jit)
    add!(jit, dylib, ts_mod)

    # NOTE(review): presumably lets the JIT resolve symbols from the running
    # process; confirm against the LLVM.jl ORC documentation.
    generator = LLVM.CreateDynamicLibrarySearchGeneratorForProcess(LLVM.get_prefix(jit))
    LLVM.add!(dylib, generator)

    # The closure must stay the last expression so it is the block's value.
    let fptr = pointer(lookup(jit, name))
      (d, s) -> ccall(fptr, Float64, (Float64, Float64), d, s)
    end
  end
end

# JIT-compiled wrapper around `f`, as returned by `compile`.
f2 = compile()

# Compare the plain Julia `f` with the JIT-ed `f2` on identical inputs.
# The second argument is fixed at 0.3; the first covers zero, positive and
# negative finite values, NaN, and both infinities.
f(0.0, 0.3)
f2(0.0, 0.3)
f(1.1, 0.3)
f2(1.1, 0.3)
f(-0.3, 0.3)
f2(-0.3, 0.3)
f(NaN, 0.3)
f2(NaN, 0.3)
f(Inf, 0.3)
f2(Inf, 0.3)
f(-Inf, 0.3)
f2(-Inf, 0.3)

I get

julia> versioninfo()
Julia Version 1.10.1
Commit 7790d6f064 (2024-02-13 20:41 UTC)
Platform Info:
  OS: Linux (x86_64-redhat-linux)
  CPU: 36 × Intel(R) Core(TM) i9-7980XE CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake-avx512)
Threads: 36 default, 0 interactive, 18 GC (on 36 virtual cores)
Environment:
  JULIA_PATH = @.
  LD_LIBRARY_PATH = /usr/local/lib/x86_64-unknown-linux-gnu/:/usr/local/lib/:/usr/local/lib/x86_64-unknown-linux-gnu/:/usr/local/lib/
  JULIA_NUM_THREADS = 36
  LD_UN_PATH = /usr/local/lib/x86_64-unknown-linux-gnu/:/usr/local/lib/

julia> f(0.0, 0.3)
-0.5477225575051661

julia> f2(0.0, 0.3)
-0.5477225575051661

julia> f(1.1, 0.3)
-7.269944779727389

julia> f2(1.1, 0.3)
-0.5477225575051661

julia> f(-0.3, 0.3)
-1.047722557505166

julia> f2(-0.3, 0.3)
-0.5477225575051661

julia> f(NaN, 0.3)
NaN

julia> f2(NaN, 0.3)
NaN

julia> f(Inf, 0.3)
-Inf

julia> f2(Inf, 0.3)
NaN

julia> f(-Inf, 0.3)
-Inf

julia> f2(-Inf, 0.3)
NaN

(@sctest) pkg> st -m LLVM StaticCompiler GPUCompiler
Status `~/.julia/environments/jsc0/Manifest.toml`
  [61eb1bfa] GPUCompiler v0.25.0
  [929cbde3] LLVM v6.5.0
  [81625895] StaticCompiler v0.6.2

However, the LLVM modules look correct:

define double @f(double %0, double %1) local_unnamed_addr #2 {
top:
  %2 = fcmp uge double %1, 0.000000e+00
  br i1 %2, label %L7, label %L6

L6:                                               ; preds = %top
  call fastcc void @julia__throw_complex_domainerror_1080()
  br label %L7

L7:                                               ; preds = %L6, %top
  %3 = fdiv double %0, %1
  %4 = fmul double %3, %3
  %5 = fmul double %4, -5.000000e-01
  %6 = call double @llvm.sqrt.f64(double %1)
  %7 = fsub double %5, %6
  ret double %7
}

and

define double @ccall_example(double %0, double %1) {
entry:
  %2 = call double @f(double %0, double %1)
  ret double %2
}

So perhaps this is a bug in the ORC JIT, in `ccall`, or some silly user error above that I am not seeing?

I saw similar behavior with 3 arguments, where again all but the last were zeroed.

brenhinkeller commented 4 months ago

That is.... odd!

brenhinkeller commented 4 months ago

Ah, is it just because of the literal pointers not being reliable?

chriselrod commented 4 months ago

> Ah, is it just because of the literal pointers not being reliable?

I don't think so, but I have no idea what's going on.