biona001 closed this issue 4 months ago.
Well, I took out the type annotations on the objective functions, and then both `objective1` and `objective3` worked. `objective2` still throws the same error, but I guess it's not that important now. Closing this.
```julia
# `normal_pdf` is not shown in this thread; a Gaussian density that takes the
# variance as its third argument (as the `std^2` calls below suggest) is assumed:
normal_pdf(x, μ, σ²) = exp(-(x - μ)^2 / (2σ²)) / sqrt(2π * σ²)

# Version 1: vectorized, builds the (N, K) matrix of component densities
function mixture_loglikelihood1(params::AbstractVector, data::AbstractVector)
    K = length(params) ÷ 3
    weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end]
    mat = normal_pdf.(data, means', stds' .^2) # (N, K)
    sum(mat .* weights', dims=2) .|> log |> sum
end

# Version 2: nested generator-based `sum` over samples and components
function mixture_loglikelihood2(params::AbstractVector, data::AbstractVector)
    K = length(params) ÷ 3
    weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end]
    mat = normal_pdf.(data, means', stds' .^2) # (N, K), not used below
    obj_true = sum(
        sum(
            weight * normal_pdf(x, mean, std^2)
            for (weight, mean, std) in zip(weights, means, stds)
        ) |> log
        for x in data
    )
end

# Version 3: explicit loops, objective re-written by me
function mixture_loglikelihood3(params::AbstractVector, data::AbstractVector)
    K = length(params) ÷ 3
    weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end]
    mat = normal_pdf.(data, means', stds' .^2) # (N, K), only used for its eltype
    obj = zero(eltype(mat))
    for x in data
        obj_i = zero(eltype(mat))
        for (weight, mean, std) in zip(weights, means, stds)
            obj_i += weight * normal_pdf(x, mean, std^2)
        end
        obj += log(obj_i)
    end
    return obj
end

# Placeholder samples; the thread's actual data generation is not shown.
data = randn(10_000)

objective1 = params -> mixture_loglikelihood1(params, data)
objective2 = params -> mixture_loglikelihood2(params, data)
objective3 = params -> mixture_loglikelihood3(params, data)
```
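All three functions compute the same Gaussian-mixture log-likelihood. Let $w_k$, $\mu_k$, $s_k$ stand for the `weights`, `means`, and `stds` slices of `params`. Assume that `normal_pdf(x, μ, σ²)` is the normal density with variance $\sigma^2$, since its definition is not shown. Under those assumptions, the objective and its partial derivatives (the gradient each AD package is asked to compute) are given below. The weights are not constrained to sum to one here, so they are treated as free parameters.

```latex
\ell(\theta) = \sum_{i=1}^{N} \log p_i, \qquad
p_i = \sum_{k=1}^{K} w_k \, \phi_{ik}, \qquad
\phi_{ik} = \frac{1}{\sqrt{2\pi s_k^2}} \exp\!\left(-\frac{(x_i - \mu_k)^2}{2 s_k^2}\right)

\frac{\partial \ell}{\partial w_k} = \sum_{i=1}^{N} \frac{\phi_{ik}}{p_i}, \qquad
\frac{\partial \ell}{\partial \mu_k} = \sum_{i=1}^{N} \frac{w_k \phi_{ik}}{p_i} \, \frac{x_i - \mu_k}{s_k^2}, \qquad
\frac{\partial \ell}{\partial s_k} = \sum_{i=1}^{N} \frac{w_k \phi_{ik}}{p_i} \, \frac{(x_i - \mu_k)^2 - s_k^2}{s_k^3}
```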
Reopening this, since the bounds error should never occur.
By the way, with `N_SAMPLES = 10000` and `N_COMPONENTS = 5`, I'm seeing the following timings:

So it seems Enzyme.jl is the fastest autodiff package in Julia, as far as I can tell, but it is still ~30% slower than JAX (Python).
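A hand-derived gradient can serve as a reference point for both correctness and speed when comparing AD packages. Below is a minimal plain-Julia sketch that uses no AD package. It assumes the thread's `[weights; means; stds]` parameter layout and a `normal_pdf` that takes the variance. The names `mixture_loglikelihood`, `mixture_loglikelihood_grad`, and `fd_grad` are hypothetical, and the evenly spaced `data` and the `params` values are made up for illustration. The sketch checks the analytic gradient against central finite differences and times one call with `@elapsed`.

```julia
# Gaussian density with *variance* σ² as the third argument (assumed to match
# the thread's `normal_pdf`, whose definition is not shown).
normal_pdf(x, μ, σ²) = exp(-(x - μ)^2 / (2σ²)) / sqrt(2π * σ²)

# Log-likelihood with the thread's layout: params = [weights; means; stds].
function mixture_loglikelihood(params::AbstractVector, data::AbstractVector)
    K = length(params) ÷ 3
    w, μ, s = @views params[1:K], params[K+1:2K], params[2K+1:end]
    obj = zero(eltype(params))
    for x in data
        p = zero(eltype(params))
        for k in 1:K
            p += w[k] * normal_pdf(x, μ[k], s[k]^2)
        end
        obj += log(p)
    end
    return obj
end

# Hand-derived gradient of the function above, returned in the same layout.
function mixture_loglikelihood_grad(params::AbstractVector, data::AbstractVector)
    K = length(params) ÷ 3
    w, μ, s = @views params[1:K], params[K+1:2K], params[2K+1:end]
    g = zeros(eltype(params), length(params))
    φ = zeros(eltype(params), K)          # component densities for one sample
    for x in data
        p = zero(eltype(params))
        for k in 1:K
            φ[k] = normal_pdf(x, μ[k], s[k]^2)
            p += w[k] * φ[k]
        end
        for k in 1:K
            r = w[k] * φ[k] / p           # responsibility of component k for x
            d = x - μ[k]
            g[k]      += φ[k] / p                      # ∂ℓ/∂w_k
            g[K + k]  += r * d / s[k]^2                # ∂ℓ/∂μ_k
            g[2K + k] += r * (d^2 - s[k]^2) / s[k]^3   # ∂ℓ/∂s_k
        end
    end
    return g
end

# Central finite differences, used only to check a gradient.
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for j in eachindex(x)
        xp = copy(x); xp[j] += h
        xm = copy(x); xm[j] -= h
        g[j] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

# Example with made-up data and parameters (K = 2 components)
data   = collect(range(-3, 3; length = 1000))
params = [0.3, 0.7, -1.0, 1.0, 0.5, 1.5]   # weights, means, stds

g  = mixture_loglikelihood_grad(params, data)
gf = fd_grad(p -> mixture_loglikelihood(p, data), params)
println("max abs difference: ", maximum(abs.(g .- gf)))

# The call above already compiled the function, so this times a warm call.
t = @elapsed mixture_loglikelihood_grad(params, data)
println("gradient time (s): ", t)
```

The same `fd_grad` check can be applied to the gradient vector returned by any of the AD packages in the benchmark. For timing, BenchmarkTools.jl's `@btime` repeats the call and reports the minimum, which is more reliable than a single `@elapsed`.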
Bounds error now fixed; closing.
Hi community,
I am trying to add Enzyme.jl to this benchmark, which compares the speed of ForwardDiff.jl, ReverseDiff.jl, Symbolics.jl, and Zygote.jl to that of JAX for computing the gradient of a relatively simple log-likelihood function. My hope is that Enzyme.jl would be as fast as JAX, or at least faster than all the other AD packages, for this problem. The actual problem I'm trying to differentiate is a much more complicated log-likelihood function. I ran the code below on …

The issue is that I get two different errors depending on how I write the objective (none of the versions work). MWE:
Enzyme.jl on objective1:
Enzyme.jl on objective2:
Including `Enzyme.API.runtimeActivity!(true)` does not change anything.

Enzyme.jl on objective3:
Any tips/suggestions would be highly appreciated.