JuliaDiff / ReverseDiff.jl

Reverse Mode Automatic Differentiation for Julia

Error when using scalar vs. vector to operate on tracked input #214

Open slwu89 opened 1 year ago

slwu89 commented 1 year ago

Hi ReverseDiff team,

I've found an error that occurs when a function being differentiated combines a tracked array with another argument (not the one we are differentiating with respect to) and that other argument is a scalar. Since that description is a bit imprecise, a minimal reproducible example is below. In the line labeled "doesn't work", the argument c is passed as a scalar; in the line labeled "works", it is passed as a length-one vector. The object d is used in further computation: when c is a scalar, a method error is raised when an element of d (dd) is used in the line aa*sin(dd).

[EDIT]: I'm on Julia 1.8.3 and the version of ReverseDiff.jl available in the Julia General registry (v1.14.4).

using ReverseDiff
using Distributions

# a,b are vectors; want Hessian w.r.t a
# c is "scalar" arg
function mwe(a,b,c)
    d = @. a/(b+c)
    e = map(a,b,d) do aa, bb, dd
        if dd < bb
            exp(bb)
        else
            aa*sin(dd)
        end
    end
    return sum(e)
end

a = rand(Normal(),5)
b = rand(Normal(),5)
c = 50.0

ReverseDiff.hessian(x -> mwe(x, b, c), a)  # doesn't work
ReverseDiff.hessian(x -> mwe(x, b, [c]), a)  # works

kishore-nori commented 1 year ago

I encountered a similar error. The following equivalent version, which replaces the broadcast with a for loop to mimic what the @. macro is "supposed to do", works (I shortened the MWE above):

using ReverseDiff
using Distributions

function mwe(a,c)
  # d = @. a/(a + c) # this doesn't work
  # d = @. 1.0/(a + c) # even this doesn't work 
  # d = @. a + c # but this works 
  # d = @. 1.0/a # this also works 
  # but below works 
  d = similar(a)
  @simd for i in eachindex(a)
    @inbounds d[i] = a[i]/(a[i] + c)
  end
  e = map(a,d) do aa, dd
    aa*dd
  end
  return sum(e)
end

a = rand(Normal(),5)
c = 50.0

mwe(a,c)

ReverseDiff.hessian(x -> mwe(x, c), a) 
# works! (haven't checked the correctness, but structure of the matrix is alright.)
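
Since the correctness of the result above was not checked, here is a minimal sketch of one way to check it. It assumes the shortened objective is sum(a[i]^2/(a[i] + c)), whose Hessian is diagonal with entries 2c^2/(a[i] + c)^3 (derived by hand). The names f, analytic_hessian and fd_hessian are hypothetical helpers introduced for illustration, not part of ReverseDiff, and the input values are arbitrary.

```julia
using LinearAlgebra

# Same objective as the shortened MWE: sum_i a[i] * a[i] / (a[i] + c)
f(a, c) = sum(a .* a ./ (a .+ c))

# Hand-derived Hessian (assumption of this sketch): each term depends only on
# a[i], so the matrix is diagonal with entries d²/da² [a²/(a+c)] = 2c²/(a+c)³.
analytic_hessian(a, c) = Diagonal(2 * c^2 ./ (a .+ c) .^ 3)

# Central finite-difference Hessian, used only as an independent sanity check.
function fd_hessian(f, x; h = 1e-3)
    n = length(x)
    H = zeros(n, n)
    for i in 1:n, j in 1:n
        ei = zeros(n); ei[i] = h
        ej = zeros(n); ej[j] = h
        H[i, j] = (f(x + ei + ej) - f(x + ei - ej) -
                   f(x - ei + ej) + f(x - ei - ej)) / (4h^2)
    end
    return H
end

a = [0.3, -1.2, 0.7, 2.1, -0.4]   # arbitrary fixed input for reproducibility
c = 50.0

H_exact = analytic_hessian(a, c)
H_fd = fd_hessian(x -> f(x, c), a)

# Largest absolute disagreement between the two Hessians
println(maximum(abs.(Matrix(H_exact) - H_fd)))
```

With ReverseDiff loaded, the matrix returned by ReverseDiff.hessian(x -> mwe(x, c), a) could be compared against analytic_hessian(a, c) in the same way, for example with isapprox.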

More interestingly, the version below, which broadcasts in place into a preallocated array, also works:

function mwe(a,c)
  # d = @. a/(a + c) # doesn't work

  # works 
  d = similar(a)
  @. d = a/(a + c)

  e = map(a,d) do aa, dd
    aa*dd
  end
  return sum(e)
end

a = rand(Normal(),5)
c = 50.0

mwe(a,c)

ReverseDiff.hessian(x -> mwe(x, c), a)
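
For reference, the three formulations in this thread (out-of-place broadcast, in-place broadcast, explicit loop) compute the same values on plain, untracked arrays. The sketch below shows this using only Base Julia; the variable names d_broadcast, d_inplace and d_loop are introduced here for illustration. This suggests, though it does not prove, that the differing behaviour comes from how the out-of-place broadcast is handled for tracked inputs rather than from the arithmetic itself.

```julia
a = [0.3, -1.2, 0.7, 2.1, -0.4]   # arbitrary fixed input
c = 50.0

# Out-of-place: `@. a/(a + c)` is shorthand for the fused broadcast below,
# which allocates a new array for the result.
d_broadcast = a ./ (a .+ c)

# In-place: `@. d = a/(a + c)` is shorthand for `d .= a ./ (a .+ c)`,
# which writes into an existing array.
d_inplace = similar(a)
d_inplace .= a ./ (a .+ c)

# Explicit loop over the indices of `a`.
d_loop = similar(a)
for i in eachindex(a)
    d_loop[i] = a[i] / (a[i] + c)
end

# Report whether all three results agree element for element
println(d_broadcast == d_inplace == d_loop)
```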