EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Missing Rmath derivatives #1620

Open mhauru opened 4 months ago

mhauru commented 4 months ago

A few Distributions.jl distributions rely on Rmath.jl, and Enzyme does not seem to have derivatives defined for those Rmath functions.

```julia
module MWE

using Distributions: NoncentralBeta, logpdf
using Distributions: NoncentralChisq, NoncentralF, NoncentralT
using Enzyme

# Each logpdf below ends up calling an Rmath.jl density (dnbeta, dnchisq, dnf, dnt).
f(x) = logpdf(NoncentralBeta(1.0, 1.0, 1.0), x[1])
g(x) = logpdf(NoncentralChisq(1.0, 1.0), x[1])
h(x) = logpdf(NoncentralF(1.0, 1.0, 1.0), x[1])
i(x) = logpdf(NoncentralT(1.0, 1.0), x[1])

# Each call fails with "No forward mode derivative found for ...".
Enzyme.gradient(Enzyme.Forward, f, [0.5])
Enzyme.gradient(Enzyme.Forward, g, [0.5])
Enzyme.gradient(Enzyme.Forward, h, [0.5])
Enzyme.gradient(Enzyme.Forward, i, [0.5])

end
```

For NoncentralBeta, the output is:

```
Current scope:
; Function Attrs: mustprogress willreturn
define "enzyme_type"="{[-1]:Float@double}" double @preprocess_julia_f_8695({} addrspace(10)* noundef nonnull align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@double, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="4757723584" "enzymejl_parmtype_ref"="2" %0) local_unnamed_addr #4 !dbg !47 {
top:
  %1 = call {}*** @julia.get_pgcstack() #5
  %ptls_field3 = getelementptr inbounds {}**, {}*** %1, i64 2
  %2 = bitcast {}*** %ptls_field3 to i64***
  %ptls_load45 = load i64**, i64*** %2, align 8, !tbaa !8
  %3 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
  %safepoint = load i64*, i64** %3, align 8, !tbaa !12
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #5, !dbg !48
  fence syncscope("singlethread") seq_cst
  %4 = addrspacecast {} addrspace(10)* %0 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !49
  %arraylen_ptr = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %4, i64 0, i32 1, !dbg !49
  %arraylen = load i64, i64 addrspace(11)* %arraylen_ptr, align 8, !dbg !49, !tbaa !18, !range !21, !alias.scope !22, !noalias !25
  %inbounds.not = icmp eq i64 %arraylen, 0, !dbg !49
  br i1 %inbounds.not, label %oob, label %idxend, !dbg !49

oob:                                              ; preds = %top
  %errorbox = alloca i64, align 8, !dbg !49
  store i64 1, i64* %errorbox, align 8, !dbg !49, !noalias !50
  %5 = addrspacecast {} addrspace(10)* %0 to {} addrspace(12)*, !dbg !49
  call void @ijl_bounds_error_ints({} addrspace(12)* noundef %5, i64* noundef nonnull align 8 %errorbox, i64 noundef 1) #6, !dbg !49
  unreachable, !dbg !49

idxend:                                           ; preds = %top
  %6 = addrspacecast {} addrspace(10)* %0 to double addrspace(13)* addrspace(11)*, !dbg !49
  %arrayptr6 = load double addrspace(13)*, double addrspace(13)* addrspace(11)* %6, align 16, !dbg !49, !tbaa !33, !alias.scope !53, !noalias !25, !nonnull !7
  %arrayref = load double, double addrspace(13)* %arrayptr6, align 8, !dbg !49, !tbaa !36, !alias.scope !39, !noalias !40
  %7 = call double @dnbeta(double %arrayref, double noundef 1.000000e+00, double noundef 1.000000e+00, double noundef 1.000000e+00, i32 noundef 1) #5, !dbg !54
  ret double %7, !dbg !54
}

No forward mode derivative found for dnbeta
 at context:   %7 = call double @dnbeta(double %arrayref, double noundef 1.000000e+00, double noundef 1.000000e+00, double noundef 1.000000e+00, i32 noundef 1) #5, !dbg !41

Stacktrace:
 [1] nbetalogpdf
   @ ~/.julia/packages/StatsFuns/mrf0e/src/rmath.jl:77
 [2] logpdf
   @ ~/.julia/packages/Distributions/ji8PW/src/univariates.jl:645
 [3] f
   @ ~/projects/Enzyme-mwes/callinst_metadata/noncentralbeta.jl:7
```

The others produce similar output, but with dnchisq, dnf, and dnt, respectively, in place of dnbeta. Reverse mode is likewise missing.
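For reference, the functions named in these errors are the Rmath densities of the noncentral beta, chi-squared, F, and t distributions. The first two have standard Poisson-mixture forms (in R's parameterization the noncentrality λ enters through Poisson weights with mean λ/2), and the noncentral F reduces to the noncentral beta by a change of variables. The x-derivative of the log-density, which a forward or reverse rule for `dnbeta` would have to supply for its first argument, follows directly from the mixture form. `dnt` is not covered here.

```math
\begin{aligned}
f_{\text{nbeta}}(x;\alpha,\beta,\lambda) &= \sum_{j=0}^{\infty} w_j\, p_j(x),
\qquad w_j = \frac{e^{-\lambda/2}(\lambda/2)^j}{j!},
\qquad p_j(x) = \frac{x^{\alpha+j-1}(1-x)^{\beta-1}}{B(\alpha+j,\beta)}, \quad 0<x<1, \\
f_{\text{nchisq}}(x;k,\lambda) &= \sum_{j=0}^{\infty} w_j\, f_{\chi^2_{k+2j}}(x), \quad x>0, \\
X \sim \text{NoncentralF}(d_1,d_2,\lambda) &\implies \frac{d_1 X}{d_1 X + d_2} \sim \text{NoncentralBeta}\!\left(\tfrac{d_1}{2},\tfrac{d_2}{2},\lambda\right), \\
\frac{\partial}{\partial x}\log f_{\text{nbeta}}(x) &= \frac{\sum_j w_j\, p_j(x)\left(\frac{\alpha+j-1}{x}-\frac{\beta-1}{1-x}\right)}{\sum_j w_j\, p_j(x)}.
\end{aligned}
```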

wsmoses commented 4 months ago

Yeah, unfortunately I've never seen those functions before, so I'm not quite sure what they are meant to compute.

Do you know of any docs for them offhand? I'd also be happy to show you how to add internal support within Enzyme for derivatives of these functions, if you're interested.
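To make "what they compute" concrete: in the IR above, `dnbeta` is called as `dnbeta(x, a, b, ncp, give_log)` with `give_log = 1`, so it returns the log-density of the noncentral beta distribution, a Poisson(λ/2)-weighted mixture of central Beta(α + j, β) densities. Below is a minimal pure-Julia sketch of that quantity and of its x-derivative (what a derivative rule would need). It assumes SpecialFunctions.jl for `loggamma` (a dependency Distributions.jl already pulls in), assumes 0 < x < 1 and λ ≥ 0, and truncates the series at a fixed `nterms = 200`, which is an assumption that suits moderate λ but is not Rmath's adaptive algorithm. The names `logbeta_`, `nbeta_logterms`, `nbeta_logpdf`, and `nbeta_dlogpdf_dx` are hypothetical helpers, not part of Rmath.jl, StatsFuns.jl, or Enzyme.

```julia
using SpecialFunctions: loggamma

# Hypothetical helper (not from Rmath.jl/StatsFuns.jl): log of the beta function.
logbeta_(a, b) = loggamma(a) + loggamma(b) - loggamma(a + b)

# Log of each Poisson-weighted beta term, for j = 0, ..., nterms-1.
# Assumes 0 < x < 1 and λ >= 0; truncating the series at `nterms` is an assumption.
function nbeta_logterms(x, α, β, λ; nterms = 200)
    h = λ / 2
    return map(0:nterms-1) do j
        # Poisson(λ/2) log-weight; the j == 0 branch avoids 0 * log(0) = NaN when λ == 0.
        logw = -h + (j == 0 ? zero(h) : j * log(h)) - loggamma(j + 1.0)
        # Central Beta(α + j, β) log-density at x.
        logp = (α + j - 1) * log(x) + (β - 1) * log1p(-x) - logbeta_(α + j, β)
        logw + logp
    end
end

# log f(x), the quantity dnbeta(x, α, β, λ, give_log = 1) returns, via log-sum-exp.
function nbeta_logpdf(x, α, β, λ; nterms = 200)
    v = nbeta_logterms(x, α, β, λ; nterms = nterms)
    m = maximum(v)
    return m + log(sum(exp.(v .- m)))
end

# d/dx log f(x): the mixture-weighted average of each component's x-score.
function nbeta_dlogpdf_dx(x, α, β, λ; nterms = 200)
    v = nbeta_logterms(x, α, β, λ; nterms = nterms)
    s = exp.(v .- maximum(v))  # unnormalized weight of each mixture component at x
    score = [(α + j - 1) / x - (β - 1) / (1 - x) for j in 0:nterms-1]
    return sum(s .* score) / sum(s)
end
```

As a sanity check, with α = β = 1 and λ = 1 the series sums in closed form to f(0.5) = 1.25·e^(-0.25), so `nbeta_logpdf(0.5, 1.0, 1.0, 1.0)` should be close to `log(1.25) - 0.25`, and should match `logpdf(NoncentralBeta(1.0, 1.0, 1.0), 0.5)` from the MWE up to truncation and rounding error. Because the sketch is plain Julia, an AD tool could in principle differentiate it directly; that has not been verified with Enzyme here.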

yebai commented 4 months ago

Related: https://github.com/compintell/Tapir.jl/issues/31