JuliaDecisionFocusedLearning / InferOpt.jl

Combinatorial optimization layers for machine learning pipelines
https://juliadecisionfocusedlearning.github.io/InferOpt.jl/
MIT License

Allow more general objectives in combinatorial layers #30

Closed: BatyLeo closed this pull request 1 year ago

BatyLeo commented 1 year ago

See #25.

codecov-commenter commented 1 year ago

Codecov Report

Patch coverage: 94.91% and project coverage change: +1.13% :tada:

Comparison is base (f9d8dab) 85.34% compared to head (aa22430) 86.48%. Report is 2 commits behind head on main.


Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main      #30      +/-   ##
==========================================
+ Coverage   85.34%   86.48%   +1.13%
==========================================
  Files          18       20       +2
  Lines         389      444      +55
==========================================
+ Hits          332      384      +52
- Misses         57       60       +3
```

| Files Changed | Coverage Δ |
|---|---|
| `src/InferOpt.jl` | 100.00% <ø> (ø) |
| `src/perturbed/abstract_perturbed.jl` | 86.20% <ø> (ø) |
| `src/perturbed/additive.jl` | 88.23% <ø> (ø) |
| `src/perturbed/multiplicative.jl` | 88.23% <ø> (ø) |
| `src/perturbed/perturbed_oracle.jl` | 85.71% <ø> (ø) |
| `src/utils/generalized_maximizer.jl` | 66.66% <66.66%> (ø) |
| `src/imitation/fenchel_young_loss.jl` | 92.30% <100.00%> (+4.55%) :arrow_up: |
| `src/imitation/spoplus_loss.jl` | 92.85% <100.00%> (+3.20%) :arrow_up: |
| `src/imitation/zero_one_loss.jl` | 100.00% <100.00%> (ø) |
| `src/regularized/soft_argmax.jl` | 100.00% <100.00%> (ø) |
| ... and 2 more | |


BatyLeo commented 1 year ago

Current state of this pull request:

**Perturbed (additive and multiplicative)**

Seems to work, both by imitation (Fenchel-Young loss) and by experience (PushForward); see the usage sketch below.
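For illustration, here is a minimal usage sketch of that combination on a toy problem; the toy objective functions and the `GeneralizedMaximizer(maximizer, g, h)` constructor signature are assumptions for this example, not the definitive API:

```julia
using InferOpt

# Toy generalized objective θ'g(y) + h(y) over one-hot vectors y.
# NB: g, h, and the constructor signature are assumptions for this sketch.
g(y; kwargs...) = 2 .* y       # linear encoding of the solution
h(y; kwargs...) = -sum(y)      # solution-dependent offset

# Since θ'g(y) + h(y) = (2θ .- 1)'y, the maximizer picks the best coordinate.
function maximizer(θ; kwargs...)
    y = zero(θ)
    y[argmax(2 .* θ .- 1)] = 1
    return y
end

generalized = GeneralizedMaximizer(maximizer, g, h)

# Learning by imitation: Fenchel-Young loss on an additively perturbed layer.
layer = PerturbedAdditive(generalized; ε=0.1, nb_samples=10)
fyl = FenchelYoungLoss(layer)

θ = randn(5)
y_true = maximizer(θ)
loss_value = fyl(θ, y_true)   # differentiable with respect to θ
```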

**SPO+**

Works, but should only be used with $\alpha = 2$: the gradient and loss values are wrong in theory when $\alpha \neq 2$. This could easily be fixed for $\alpha \neq 1$, but for $\alpha = 1$ we would need to be able to compute $\arg\max_y \theta^\top g(y)$.
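To make the $\alpha$ issue concrete, here is a sketch of an $\alpha$-parameterized SPO+ loss in the generalized setting; it assumes the `GeneralizedMaximizer` objective $\theta^\top g(y) + h(y)$ and an $(\alpha - 1)$ weight on $h$, neither of which is spelled out in this thread:

$$\ell_\alpha(\theta; \bar\theta, \bar y) = \max_{y \in \mathcal{Y}} \left\{ (\alpha\theta - \bar\theta)^\top g(y) + (\alpha - 1)\, h(y) \right\} - (\alpha\theta - \bar\theta)^\top g(\bar y) - (\alpha - 1)\, h(\bar y)$$

Under this form, for $\alpha \neq 1$ the inner maximization reduces to a call of the available oracle $\theta \mapsto \arg\max_y \{\theta^\top g(y) + h(y)\}$ with rescaled costs, while at $\alpha = 1$ the $h$ term vanishes and one would need $\arg\max_y \theta^\top g(y)$ instead, matching the limitation above.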

BatyLeo commented 1 year ago

Now that #80 and #65 are merged, I think this PR is finally ready to be merged. I just need some feedback on the points discussed below.

gdalle commented 1 year ago

I think support for Regularized would be nice. Does it also work with the latest Perturbed revamp?

BatyLeo commented 1 year ago

> I think support for Regularized would be nice. Does it also work with the latest Perturbed revamp?

I could probably make it compatible with `AbstractRegularized` though, even if the Frank-Wolfe forward pass won't work as expected.

gdalle commented 1 year ago

> I could probably make it compatible with AbstractRegularized tho, even if the Frank Wolfe forward pass won't work as expected.

Indeed, FW will fail, but the whole idea of `AbstractRegularized` was to allow for custom solvers, so I think it's worth the adaptation.

BatyLeo commented 1 year ago

This would need a way to get the maximizer out of an `AbstractRegularized`. Should we add a method for that to the `AbstractRegularized` interface? It would only be used with `GeneralizedMaximizer`, which is a bit weird.

gdalle commented 1 year ago

> It would only be used when using GeneralizedMaximizer, that's a bit weird

Is it a method that's easy to code?

BatyLeo commented 1 year ago

Yes, it would only need the following interface function:

```julia
function get_maximizer end
```

We would just need to specify that it only has to be implemented for `AbstractRegularized` layers that support `GeneralizedMaximizer`.
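For illustration, here is a hypothetical sketch of what implementing this interface could look like; the concrete type `MyRegularized` is made up for the example and is not part of the PR:

```julia
# Hypothetical concrete layer, for illustration only.
struct MyRegularized{M<:GeneralizedMaximizer} <: AbstractRegularized
    maximizer::M
end

# The single extra interface method: expose the underlying generalized
# maximizer so that losses can evaluate the objective θ'g(y) + h(y).
get_maximizer(regularized::MyRegularized) = regularized.maximizer
```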

gdalle commented 1 year ago

Your call. Btw, you can also use RequiredInterfaces.jl to test the interfaces; it's really easy and convenient.

BatyLeo commented 1 year ago

I added a new abstract type `AbstractRegularizedGeneralizedMaximizer <: AbstractRegularized` with the additional required interface method, declared using RequiredInterfaces.jl. Now a `GeneralizedMaximizer` can be regularized!
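For reference, a minimal sketch of what that declaration could look like; the `@required` macro usage is assumed from the RequiredInterfaces.jl documentation and may differ from the actual PR code:

```julia
using RequiredInterfaces: @required

# New abstract type for regularized layers wrapping a GeneralizedMaximizer.
abstract type AbstractRegularizedGeneralizedMaximizer <: AbstractRegularized end

# Declare get_maximizer as a required method of the interface, so that
# RequiredInterfaces.jl can check whether subtypes implement it.
@required AbstractRegularizedGeneralizedMaximizer begin
    get_maximizer(::AbstractRegularizedGeneralizedMaximizer)
end
```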

LouisBouvier commented 1 year ago

I see no issue with this PR. Nice job!