JuliaML / LossFunctions.jl

Julia package of loss functions for machine learning.
https://juliaml.github.io/LossFunctions.jl/stable

Refactoring of codebase #126

Closed juliohm closed 1 month ago

juliohm commented 4 years ago

Dear all,

I would like to propose a major refactoring of the codebase to simplify future additions and generalizations and to facilitate contributions. I think we are overusing macros in the package, mostly just to share docstrings among the various losses, and we could eliminate this entry barrier for potential contributors.

In particular, I would like to suggest a few modifications here, and ask for your approval before I start submitting PRs.

Suggestions for improvement

  1. Can we get rid of the value_fun, deriv_fun, deriv2_fun, and value_deriv_fun functionality? I understand that these functions were created in the past because the language didn't have efficient lambdas and closures. Moving forward, I think we could stick to a single interface for evaluating losses (value, deriv, and deriv2), where the last two functions could have fallback implementations via auto-diff when the user only implements value.

  2. Similarly, can we get rid of the following functor syntax in the top source file, and stick to the single API defined above?

# allow using some special losses as function
(loss::ScaledSupervisedLoss)(args...) = value(loss, args...)
(loss::WeightedBinaryLoss)(args...)   = value(loss, args...)

# allow using SupervisedLoss as function
for T in filter(isconcretetype, subtypes(SupervisedLoss))
    @eval (loss::$T)(args...) = value(loss, args...)
end

# allow using MarginLoss and DistanceLoss as function
for T in union(subtypes(DistanceLoss), subtypes(MarginLoss))
    @eval (loss::$T)(args...) = value(loss, args...)
end
  3. Can we get rid of the internal (non-exported) types Deriv and Deriv2? I understand that they are only used internally for plotting derivatives in src/supervised/io.jl.

  4. Can we simplify the test suite? Currently, it seems to be testing the same functionality with hundreds of numbers, giving the illusion of good coverage and making any run of the tests take forever when tests fail (I/O bottlenecks).

  5. Can we simplify the loop in src/supervised/supervised.jl that iterates over value, deriv, and deriv2? In particular, I am curious whether we could define (without metaprogramming) only the aggregation of value, and then rely on auto-diff to compute deriv and deriv2. This is a modification that we need to think about more carefully, but it could simplify the codebase tremendously. If auto-diff does not work for all losses, we can always provide a specific implementation to override the auto-diff defaults (a minimal sketch of such a fallback follows this list). My questions are: can auto-diff be performed at compile time? Do we get any slowdown if we follow this design?
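A minimal sketch of the auto-diff fallback idea from items (1) and (5), assuming ForwardDiff.jl as the backend (an illustration only, not a proposed dependency) and a hypothetical HalfL2Loss standing in for a concrete loss:

# Sketch only: generic auto-diff fallbacks for deriv/deriv2 when a loss
# implements nothing but value. Argument order follows the API at the time
# of this comment, i.e. (target, output).
using ForwardDiff

abstract type SupervisedLoss end

# fallbacks differentiate value with respect to the output
deriv(loss::SupervisedLoss, target, output) =
    ForwardDiff.derivative(o -> value(loss, target, o), output)

deriv2(loss::SupervisedLoss, target, output) =
    ForwardDiff.derivative(o -> deriv(loss, target, o), output)

# a new loss then only needs value; both derivatives come for free
struct HalfL2Loss <: SupervisedLoss end
value(::HalfL2Loss, target, output) = abs2(output - target) / 2

deriv(HalfL2Loss(), 1.0, 2.0)   # 1.0
deriv2(HalfL2Loss(), 1.0, 2.0)  # 1.0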

I will start working on separate PRs for items (1), (2), (3), and (4). I need your input before I can start working on (5).

joshday commented 4 years ago
  1. I'm okay dropping any _fun function, but there are some micro-optimizations with things like value_deriv that can help you avoid calculating the same thing twice.

  2. I think starting with Julia 1.3, we can simplify all of that to

    (l::SupervisedLoss)(args...) = value(l, args...)
  3. As long as we still have a way to create those plots, that's fine.

  4. I mostly like how comprehensive the tests are, especially with respect to checking type stability. I wouldn't call the good coverage an illusion. I do wish the tests ran faster, though.

  5. Looping over @eval is always hard to understand, so a simplification there sounds good to me. However, auto-diff seems like an unnecessary complication when we are already working with pretty simple, cheap-to-calculate derivatives.

(edited to match numbering above)

juliohm commented 4 years ago

Thank you @joshday, I've updated the list in the original thread, so your bullet (4) is now (5). Sorry for the late change. I will start working on the simpler items, but the new item (4) is already making it difficult to modify the codebase, so I will start with it.

juliohm commented 4 years ago

Suggestions (1) and (2) have been addressed. Can we trigger a new minor release? What should we do in terms of updating deprecated.jl? I would appreciate guidance on best practices.

CarloLucibello commented 4 years ago

Other suggestions:

joshday commented 4 years ago

@juliohm Looks like you removed ScaledDistanceLoss in favor of Scaled. 👠on the change, but can you add back ScaledDistanceLoss and properly deprecate it?

There are some downstream breakages in JuliaDB.
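One possible way to do this with Base's @deprecate macro, assuming Scaled accepts the same constructor arguments as ScaledDistanceLoss did (a sketch for deprecated.jl; it only covers constructor call sites, and the actual deprecation may differ):

# Sketch only: forward old constructor calls to the new type with a warning.
# Meant to live next to the other entries in deprecated.jl, where Scaled exists.
@deprecate ScaledDistanceLoss(args...) Scaled(args...)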

juliohm commented 4 years ago

@joshday of course, I will submit a PR to LossFunctions.jl with a deprecation warning and tag a patch release. Thanks for reporting the issue.

juliohm commented 4 years ago

Dear all, I would like to start another round of simplifications in the project. In particular, I would like to suggest removing the bang versions value!, deriv!, and deriv2! in favor of the dot syntax idiom:

y = Vector{Float64}(undef, 100)
for i in 1:1000
  y .= f.(rand(100))  # the dotted call and assignment write the results into y in place
end

So there is no need to define an f! as we do here. Even if there are corner cases where the dot syntax fails, I don't think we have many use cases for the aforementioned bang versions. Most users will either use an aggregation of the losses in a tight training loop or post-process the individual losses for all observations, and in both cases the allocating versions seem fine.
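To make the proposal concrete, here is a self-contained illustration with a stand-in scalar loss ℓ; the name and the element-wise definition are placeholders, not the package API:

# stand-in for a scalar loss evaluation such as value(loss, target, output)
ℓ(target, output) = abs2(output - target)

targets = rand(100)
outputs = rand(100)
buffer  = similar(targets)

# instead of a bang version like value!(buffer, loss, targets, outputs),
# broadcasting writes the element-wise results into the preallocated buffer
buffer .= ℓ.(targets, outputs)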

Please let me know what you think about this proposal; I can start working on it right away once you share your point of view.

Best,

joshday commented 4 years ago

It's such a small amount of code to manage and you can already use broadcasting; I'd prefer to keep them.

That being said, I'd love to see the @eval craziness that currently implements them go away.

juliohm commented 4 years ago

@joshday I'd say that the amount of code is only one possible metric. In my opinion, having a minimal exposed API is more important: it can facilitate future simplifications and maintenance in the long run. I don't see practical use cases for the bang versions, so if you can clarify a use case that justifies them, I'd be happy to learn. Otherwise, please let's remove them and keep things simple.

juliohm commented 2 years ago

Coming back to this issue...

We now have a self-contained version of LossFunctions.jl that does not depend on the names in LearnBase.jl 🙏🏽 This means that we can sort out this cleanup internally in the package as we wish and then tag a breaking release that addresses the ObsDim and AggMode business, which I also find very annoying. The code in LossFunctions.jl is unnecessarily complicated because of these submodules.

juliohm commented 1 year ago

Coming back to this issue...

I would like to drop all features related to ObsDim. In particular, I would like to refactor the code so that it takes datasets as iterables of observations. Currently we have to specify the observation dimension in n-dimensional arrays, and this is a real mess, especially because most loss function definitions take scalars as inputs.

I will try to find some time to work on this cleanup in the coming weeks. Please let me know if you have any objections. The idea is to be able to write something like value(L2DistLoss(), iter1, iter2), where these iterables may contain abstract vectors, tuples, named tuples, or any 1-dimensional object.
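A sketch of how such an iterable-based API could be built on top of the existing scalar method; all names below (L2Dist, lossvalues, meanvalue) are illustrative stand-ins, not the package API:

using Statistics: mean

# stand-in scalar method; in the package this role is played by value(loss, t, o)
struct L2Dist end
value(::L2Dist, target, output) = abs2(output - target)

# evaluate over any two iterables of observations (vectors, tuples, generators, ...)
lossvalues(loss, iter1, iter2) = [value(loss, t, o) for (t, o) in zip(iter1, iter2)]

# aggregations become ordinary Julia reductions instead of AggMode options
meanvalue(loss, iter1, iter2) = mean(value(loss, t, o) for (t, o) in zip(iter1, iter2))

lossvalues(L2Dist(), [1.0, 2.0], (1.5, 2.5))   # [0.25, 0.25]
meanvalue(L2Dist(), 1:3, (2, 3, 4))            # 1.0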

juliohm commented 1 year ago

Finished the cleanup of ObsDim. Next I plan to swap the order of arguments from (target, output) to (output, target) to match other ecosystems such as Flux.jl.

juliohm commented 1 year ago

Finished the reordering of arguments to match other ecosystems. Will continue with other small improvements in the master branch before preparing a new breaking release.

juliohm commented 1 year ago

The next planned change is to remove the value function in favor of the functor interface.
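For illustration, a call site after this change would look like the following, assuming the functor takes (output, target) as described in the previous comments (a sketch against the development version, not a guarantee of the released API):

using LossFunctions

ŷ, y = 0.8, 1.0      # model output and target
L2DistLoss()(ŷ, y)   # functor call; previously value(L2DistLoss(), ŷ, y)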

juliohm commented 1 year ago

Removed value function. Preparing the next breaking release.

juliohm commented 1 month ago

We concluded this refactoring months ago. Closing the issue.