JuliaTrustworthyAI / CounterfactualExplanations.jl

A package for Counterfactual Explanations and Algorithmic Recourse in Julia.
https://www.taija.org/CounterfactualExplanations.jl/
MIT License

Interface to `<: Supervised` MLJ models #69

Closed: pat-alt closed this issue 5 months ago

pat-alt commented 2 years ago

This issue is reserved for the TU Delft Student Software Project '23.

MLJ is a popular machine learning framework for Julia. It ships with a large suite of common machine learning models, both supervised and unsupervised. It seems natural to interface this package with MLJ, but differentiability is currently a major challenge: to use any of our counterfactual generators to explain MLJ models, those models need to be differentiable with respect to their input features. Still, this is worth exploring.
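To illustrate why gradients with respect to the features (rather than the model weights) are the crux, here is a minimal sketch of a gradient-based counterfactual search in plain Julia. It is not one of the package's generators: the toy logistic model and the names `predict_proba`, `grad_x` and `search_counterfactual` are hypothetical, and the sketch omits the distance penalty that real generators typically add.

```julia
using LinearAlgebra: dot

# Logistic link mapping a score to a class-1 probability.
sigmoid(z) = 1 / (1 + exp(-z))

# Hypothetical toy model: logistic regression with weights w and bias b,
# so p(y = 1 | x) = sigmoid(w ⋅ x + b).
predict_proba(x, w, b) = sigmoid(dot(w, x) + b)

# Gradient of the cross-entropy loss towards `target` with respect to the
# FEATURES x (not the weights): ∇ₓ loss = (p - target) * w.
grad_x(x, w, b, target) = (predict_proba(x, w, b) - target) .* w

# Plain gradient descent on the features until the target class reaches `threshold`.
# No proximity penalty is included, to keep the sketch minimal.
function search_counterfactual(x0, w, b; target = 1, η = 0.5, maxiter = 1_000, threshold = 0.9)
    x = copy(x0)  # leave the factual input untouched
    for _ in 1:maxiter
        p = predict_proba(x, w, b)
        (target == 1 ? p : 1 - p) >= threshold && break
        x .-= η .* grad_x(x, w, b, target)
    end
    return x
end
```

For example, `search_counterfactual([-1.0, 1.0], [2.0, -1.0], 0.0)` moves a point that the toy model assigns to class 0 until its class-1 probability reaches 0.9. The update only works because `grad_x` exists; for a model whose prediction is piecewise constant in `x`, the same loop would never move.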

I propose the following steps:

  1. Implement a basic interface to MLJ (essentially, an `AbstractFittedModel` subtype that wraps `MLJ.Supervised` models).
  2. From the MLJ model list, identify which ML models fulfil the differentiability criterion. Note that some models, like decision trees, may be differentiable after probability calibration. See below for a potential starting point. Start by focusing on pure Julia models before dealing with non-native models (like sklearn).
  3. Ideally, I think we would like a single `MLJModel <: AbstractFittedModel` type that can handle all (compatible) supervised MLJ models. To this end, we will need a mechanism to differentiate between compatible and incompatible models.
  4. Thoroughly test and document your contributions.
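Steps 1 and 3 could be combined roughly as follows. This is a minimal sketch, not the package's actual implementation: `AbstractFittedModel` is redefined locally so the snippet runs on its own, and `MLJModelWrapper`, `ToyLogistic`, `ToyStump`, `differentiability`, `isdifferentiable` and `check_gradient_access` are hypothetical names. It uses a trait so that compatibility is decided by dispatch on the wrapped model's type, with non-differentiability as the default.

```julia
# Local stand-in for the package's abstract model type (so this snippet is self-contained).
abstract type AbstractFittedModel end

# Hypothetical toy models standing in for MLJ supervised models.
struct ToyLogistic
    w::Vector{Float64}
    b::Float64
end

struct ToyStump
    feature::Int
    split::Float64
end

# Trait: is the model differentiable with respect to its input features?
abstract type Differentiability end
struct IsDifferentiable <: Differentiability end
struct NotDifferentiable <: Differentiability end

# Default: assume a model is NOT differentiable unless its type opts in.
differentiability(::Type) = NotDifferentiable()
differentiability(::Type{ToyLogistic}) = IsDifferentiable()

# A single wrapper type for all supported models.
struct MLJModelWrapper{M} <: AbstractFittedModel
    model::M
end

isdifferentiable(m::MLJModelWrapper{M}) where {M} = differentiability(M) isa IsDifferentiable

# A gradient-based generator could call this up front to reject incompatible models.
function check_gradient_access(m::MLJModelWrapper)
    isdifferentiable(m) ||
        throw(ArgumentError("$(typeof(m.model)) is not differentiable with respect to its features"))
    return true
end
```

With this design, supporting a new model (for example a tree after probability calibration) only requires adding one `differentiability` method for its type; everything else stays excluded by default.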

This is a challenging task and it is not critical that you succeed at everything. But we would like to aim for the following minimum achievements:

Previous attempts

I have tried this in the past, which might or might not be a good starting point:

  1. At this point, all of the counterfactual generators need gradient access and currently leverage Zygote.jl for auto-diff. I am not sure if all MLJ models can just be "auto-diffed" in that sense, but some early experiments with EvoTrees have shown that, in principle, gradient-based counterfactual generators should be applicable (see here).
  2. That being said, Zygote.jl didn't work in this case and I had to rely on ForwardDiff (see here). The problem with trees is that the counterfactual loss function is not smooth, so taking gradients just resulted in gradients with all elements equal to zero (at least, I think the non-smoothness was the issue here). It would still be preferable to use Zygote if possible.
  3. (Non-)Differentiability of models may be a more general issue.
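The all-zero gradients described in point 2 can be reproduced without any ML library. In this minimal sketch (all names hypothetical), a single decision stump is piecewise constant, so its gradient is zero almost everywhere, whereas replacing the hard split with a sigmoid yields a usable slope. A central finite difference stands in for Zygote/ForwardDiff only so that the snippet runs with Base Julia, and sigmoid smoothing is just one illustrative choice, not necessarily how EvoTrees or the package handle it.

```julia
# Hypothetical one-split "tree": class-1 probability 0.9 if x[1] > 0.5, else 0.1.
stump_proba(x) = x[1] > 0.5 ? 0.9 : 0.1

# Hypothetical smoothed surrogate: the hard split is replaced by a sigmoid with temperature τ.
smooth_proba(x; τ = 0.1) = 0.1 + 0.8 / (1 + exp(-(x[1] - 0.5) / τ))

# Central finite-difference gradient, used only to inspect the slope of f at x.
function fd_gradient(f, x; h = 1e-4)
    g = zeros(length(x))
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end
```

Evaluated away from the split, e.g. at `[0.2, 0.3]`, `fd_gradient(stump_proba, ...)` returns exactly zero in every coordinate, so a gradient-based generator has nothing to follow, while `fd_gradient(smooth_proba, ...)` points towards the split along the first feature.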

pat-alt commented 1 year ago

MLJFlux is probably the most obvious place to start (see the related discussion here).

pat-alt commented 5 months ago

This is now implemented in principle (#450), but by default MLJ models are assumed to be non-differentiable: the `MLJBase.predict` call and other functions don't play nicely with Zygote.