JuliaDecisionFocusedLearning / InferOpt.jl

Combinatorial optimization layers for machine learning pipelines
https://juliadecisionfocusedlearning.github.io/InferOpt.jl/
MIT License

Fast differentiable sorting and ranking #88

Closed: BatyLeo closed this pull request 11 months ago.

BatyLeo commented 1 year ago

Following this Discourse post, here is a Julia implementation of *Fast Differentiable Sorting and Ranking* (Blondel et al., 2020); see here for the original Python implementation.

This includes:


Reproducing figures from the paper:

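```julia
using InferOpt, Plots, Zygote

# Second-largest entry of [0, x, 1, 2] as a function of x: hard sort, soft sort (L2), and its derivative.
# The same ε is passed to the value and to the gradient so that both curves describe the same function.
plot(x -> sort([0.0, x, 1.0, 2.0]; rev=true)[2], label="Sort")
plot!(x -> soft_sort_l2([0.0, x, 1.0, 2.0]; ε=1.0, rev=true)[2], label="Soft sort l2")
plot!(y -> gradient(x -> soft_sort_l2([0.0, x, 1.0, 2.0]; ε=1.0, rev=true)[2], y)[1], label="Soft sort l2 derivative")
```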

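The Codecov report in this thread lists `isotonic_l2.jl` and `projection.jl`, which suggests the L2 variants follow the paper's reduction to isotonic regression. As a hedged illustration of that reduction (not the code added in this PR), here is a minimal self-contained sketch; every name in it (`isotonic_l2_decreasing`, `projection_permutahedron_l2`, `soft_sort_l2_sketch`, `soft_rank_l2_sketch`) is hypothetical. With ρ = (n, n−1, …, 1), the soft sort is taken to be the Euclidean projection of ρ/ε onto the permutahedron of θ, and the soft rank the projection of −θ/ε onto the permutahedron of ρ. Each projection reduces to a decreasing isotonic regression, solved here with Pool Adjacent Violators (PAV). The ε scaling is an assumption and may not match InferOpt's `ε` keyword exactly.

```julia
# PAV for decreasing isotonic regression:
# minimize ½‖v - y‖² subject to v[1] ≥ v[2] ≥ … ≥ v[n]
function isotonic_l2_decreasing(y::AbstractVector{<:Real})
    sums = Float64[]  # sum of y over each pooled block
    lens = Int[]      # number of entries in each pooled block
    for yi in y
        push!(sums, yi)
        push!(lens, 1)
        # merge the last two blocks while their means violate the decreasing order
        while length(lens) > 1 && sums[end-1] / lens[end-1] < sums[end] / lens[end]
            bs = pop!(sums)
            bl = pop!(lens)
            sums[end] += bs
            lens[end] += bl
        end
    end
    # expand each block mean back to one value per entry
    v = Vector{Float64}(undef, length(y))
    i = 1
    for (bs, bl) in zip(sums, lens)
        v[i:i+bl-1] .= bs / bl
        i += bl
    end
    return v
end

# Euclidean projection of z onto the permutahedron of w:
# sort z and w in decreasing order, subtract the isotonic fit, then undo the sort of z.
function projection_permutahedron_l2(z::AbstractVector{<:Real}, w::AbstractVector{<:Real})
    w_sorted = sort(w; rev=true)
    σ = sortperm(z; rev=true)
    s = z[σ]
    v = isotonic_l2_decreasing(s .- w_sorted)
    out = Vector{Float64}(undef, length(z))
    out[σ] = s .- v  # entry σ[k] of the result is s[k] - v[k]
    return out
end

# Hypothetical soft sort (decreasing) and soft rank (rank 1 = largest), with ρ = (n, …, 1)
soft_sort_l2_sketch(θ; ε=1.0) = projection_permutahedron_l2(collect(length(θ):-1:1) ./ ε, θ)
soft_rank_l2_sketch(θ; ε=1.0) = projection_permutahedron_l2(-θ ./ ε, collect(length(θ):-1:1))
```

For example, `soft_sort_l2_sketch([3.0, 1.0, 2.0]; ε=0.01)` is ≈ `[3.0, 2.0, 1.0]` and `soft_rank_l2_sketch([3.0, 1.0, 2.0]; ε=0.01)` is ≈ `[1.0, 3.0, 2.0]`. As ε grows, the soft sort shrinks toward the mean of θ and the soft rank toward (n + 1)/2, since the projected point tends to 0.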

```julia
# Same setting as above, with KL regularization instead of L2
plot(x -> sort([0.0, x, 1.0, 2.0]; rev=true)[2], label="Sort")
plot!(x -> soft_sort_kl([0.0, x, 1.0, 2.0]; ε=1.0, rev=true)[2], label="Soft sort kl")
plot!(y -> gradient(x -> soft_sort_kl([0.0, x, 1.0, 2.0]; ε=1.0, rev=true)[2], y)[1], label="Soft sort kl derivative")
```

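The KL variants can be sketched in the same spirit, assuming the paper's entropic regularization: the output is the logarithm of the KL projection of exp(z) onto the permutahedron of exp(w), and the pooled value of a PAV block B has the closed form γ(B) = log Σ_{i∈B} exp(s_i) − log Σ_{i∈B} exp(w_i), where s is the decreasingly sorted input and w the decreasingly sorted weights. This is a hedged sketch, not the PR's `isotonic_kl.jl`; all names (`logaddexp_sketch`, `isotonic_kl_decreasing`, `projection_permutahedron_kl`, `soft_sort_kl_sketch`, `soft_rank_kl_sketch`) are hypothetical. Each block keeps running log-sum-exps of s and w, so merging two blocks costs O(1).

```julia
# Numerically stable log(exp(a) + exp(b))
logaddexp_sketch(a, b) = max(a, b) + log1p(exp(-abs(a - b)))

# PAV for the KL isotonic problem with decreasing constraint v[1] ≥ … ≥ v[n];
# block value γ(B) = logsumexp(s[B]) - logsumexp(w[B])
function isotonic_kl_decreasing(s::AbstractVector{<:Real}, w::AbstractVector{<:Real})
    lse_s = Float64[]  # log-sum-exp of s over each block
    lse_w = Float64[]  # log-sum-exp of w over each block
    lens = Int[]       # block sizes
    for i in eachindex(s, w)
        push!(lse_s, s[i])
        push!(lse_w, w[i])
        push!(lens, 1)
        # merge while the block values violate the decreasing order
        while length(lens) > 1 && lse_s[end-1] - lse_w[end-1] < lse_s[end] - lse_w[end]
            bs = pop!(lse_s)
            bw = pop!(lse_w)
            bl = pop!(lens)
            lse_s[end] = logaddexp_sketch(lse_s[end], bs)
            lse_w[end] = logaddexp_sketch(lse_w[end], bw)
            lens[end] += bl
        end
    end
    v = Vector{Float64}(undef, length(s))
    i = 1
    for (bs, bw, bl) in zip(lse_s, lse_w, lens)
        v[i:i+bl-1] .= bs - bw
        i += bl
    end
    return v
end

# Log-space KL projection of z onto the permutahedron of w (same sort / unsort pattern as L2)
function projection_permutahedron_kl(z::AbstractVector{<:Real}, w::AbstractVector{<:Real})
    w_sorted = sort(w; rev=true)
    σ = sortperm(z; rev=true)
    s = z[σ]
    v = isotonic_kl_decreasing(s, w_sorted)
    out = Vector{Float64}(undef, length(z))
    out[σ] = s .- v
    return out
end

# Hypothetical KL soft sort (decreasing) and soft rank (rank 1 = largest), with ρ = (n, …, 1)
soft_sort_kl_sketch(θ; ε=1.0) = projection_permutahedron_kl(collect(length(θ):-1:1) ./ ε, θ)
soft_rank_kl_sketch(θ; ε=1.0) = projection_permutahedron_kl(-θ ./ ε, collect(length(θ):-1:1))
```

For small ε these agree with hard sorting and ranking, e.g. `soft_sort_kl_sketch([3.0, 1.0, 2.0]; ε=0.01)` is ≈ `[3.0, 2.0, 1.0]`. For large ε the soft sort tends to the log-mean-exp of θ, log((1/n) Σᵢ exp(θᵢ)), rather than to the arithmetic mean.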

```julia
# Rank of the first entry x of [x, 3, 1, 2] as a function of x: hard ranking, soft rank (L2), and its derivative
plot(x -> ranking([x, 3.0, 1.0, 2.0]; rev=true)[1], label="Ranking")
plot!(x -> soft_rank_l2([x, 3.0, 1.0, 2.0]; ε=1.0, rev=true)[1], label="Soft rank l2")
plot!(y -> gradient(x -> soft_rank_l2([x, 3.0, 1.0, 2.0]; ε=1.0, rev=true)[1], y)[1], label="Soft rank l2 derivative")
```


```julia
# Same setting as above, with KL regularization instead of L2
plot(x -> ranking([x, 3.0, 1.0, 2.0]; rev=true)[1], label="Ranking")
plot!(x -> soft_rank_kl([x, 3.0, 1.0, 2.0]; ε=1.0, rev=true)[1], label="Soft rank kl")
plot!(y -> gradient(x -> soft_rank_kl([x, 3.0, 1.0, 2.0]; ε=1.0, rev=true)[1], y)[1], label="Soft rank kl derivative")
```


codecov-commenter commented 1 year ago

Codecov Report

Patch coverage: 91.66%; project coverage change: +1.95%.

Comparison is base (f9d8dab) 85.34% compared to head (ff85e4a) 87.29%.


Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main      #88      +/-   ##
==========================================
+ Coverage   85.34%   87.29%   +1.95%
==========================================
  Files          18       23       +5
  Lines         389      559     +170
==========================================
+ Hits          332      488     +156
- Misses         57       71      +14
```

| Files Changed | Coverage Δ |
|---|---|
| `src/InferOpt.jl` | `100.00% <ø> (ø)` |
| `src/regularized/soft_rank.jl` | `60.00% <60.00%> (ø)` |
| `src/utils/isotonic_regression/isotonic_kl.jl` | `100.00% <100.00%> (ø)` |
| `src/utils/isotonic_regression/isotonic_l2.jl` | `100.00% <100.00%> (ø)` |
| `src/utils/isotonic_regression/projection.jl` | `100.00% <100.00%> (ø)` |

... and 1 file with indirect coverage changes.
