JuliaML / LossFunctions.jl

Julia package of loss functions for machine learning.
https://juliaml.github.io/LossFunctions.jl/stable

Very slow loss aggregation #172

Closed MilesCranmer closed 11 months ago

MilesCranmer commented 11 months ago

The mean and sum implementations in this package are extremely slow as they rely on generators. Could this be changed to a faster implementation? For example, here is how the current implementation would compute sum-squared-error:

square(x) = x ^ 2

# example of generator-based approach:
function sse(outputs, targets)
    sum(square(ŷ - y) for (ŷ, y) in zip(outputs, targets))
end

(i.e., like this code: https://github.com/JuliaML/LossFunctions.jl/blob/7318c582874988fcc325464831aa2ce94eb36ff2/src/losses.jl#L37-L39)

which gives us the following time:

julia> using BenchmarkTools  # provides @btime

julia> @btime sse(outputs, targets) setup=(outputs=randn(100_000); targets=randn(100_000))
  92.833 μs (0 allocations: 0 bytes)

but if we change this to an approach using sum(<function>, <indices>), it's much faster:

function sse2(outputs, targets)
    sum(i -> square(outputs[i] - targets[i]), eachindex(outputs, targets))
end
julia> @btime sse2(outputs, targets) setup=(outputs=randn(100_000); targets=randn(100_000))
  26.708 μs (0 allocations: 0 bytes)

which is a 3.5x speedup.
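For reference, the indexed pattern generalizes to any loss function, not just squared error. Below is a minimal sketch of what a generic replacement could look like; agg_sum and agg_mean are hypothetical names for illustration, not the package's actual API:

```julia
# Hypothetical sketch: generic indexed aggregation over a loss callable,
# mapping (prediction, target) -> value. Not the package's actual API.
function agg_sum(loss, outputs::AbstractArray, targets::AbstractArray)
    # eachindex(outputs, targets) also verifies that both arrays share
    # the same indices, so a length mismatch fails early.
    sum(i -> loss(outputs[i], targets[i]), eachindex(outputs, targets))
end

agg_mean(loss, outputs, targets) =
    agg_sum(loss, outputs, targets) / length(outputs)

# Usage with the squared-error loss from above:
square(x) = x^2
sqerr(ŷ, y) = square(ŷ - y)
# agg_sum(sqerr, outputs, targets)
```

This keeps the zero-allocation behavior of the indexed approach while letting the same helper serve every loss type.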

Could this be implemented as the default loss calculation? I thought this was the method that used to be used. Perhaps it got changed in the recent refactoring?

juliohm commented 11 months ago

I don't remember using this feature anywhere else in downstream packages. If you can submit a PR dropping support for generators, we can review and merge it.


MilesCranmer commented 11 months ago

I don't remember using this feature anywhere else in downstream packages.

Sorry, I don't think I was clear. This is about the main sum function in LossFunctions.jl, which is the primary interface to the package. It currently relies on a generator (e.g., sum(f(x) for x in X)) for summing, which is very slow:

https://github.com/JuliaML/LossFunctions.jl/blob/7318c582874988fcc325464831aa2ce94eb36ff2/src/losses.jl#L37-L39

If you can submit a PR dropping support for generators, we can review and merge it.

Sure!

MilesCranmer commented 11 months ago

Implemented in https://github.com/JuliaML/LossFunctions.jl/pull/173

juliohm commented 11 months ago

Yes, I meant using the generator-based version in downstream packages that depend on LossFunctions.jl.


MilesCranmer commented 11 months ago

Got it, thanks!