`MLJFlux` is probably the most obvious place to start for this (see related discussion here).
This is in principle now implemented (#450), but by default MLJ models are assumed to be non-differentiable (the `MLJBase.predict` call and other functions don't play nicely with Zygote).
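To illustrate the kind of failure meant here, a minimal sketch (the model choice and data below are illustrative assumptions, not taken from the issue):

```julia
using MLJ, Zygote

# Toy data and an (illustrative) tree-based regressor loaded via MLJ:
X, y = make_regression(100, 3)
Tree = @load DecisionTreeRegressor pkg=DecisionTree verbosity=0
mach = machine(Tree(), X, y)
fit!(mach)

# Prediction as a function of a single feature vector:
f(x) = only(predict(mach, MLJ.table(reshape(x, 1, :))))

# Differentiating through MLJBase.predict typically errors or yields
# useless gradients, because the call is not written with AD in mind:
Zygote.gradient(f, [0.1, 0.2, 0.3])
```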
This is an issue reserved for the TU Delft Student Software Project '23
MLJ is a popular machine learning framework for Julia. It ships with a large suite of common machine learning models, both supervised and unsupervised. It seems natural to interface this package with MLJ, although differentiability is currently a major challenge: to use any of our counterfactual generators to explain MLJ models, those models need to be differentiable with respect to their features. Still, this is worth exploring.
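To make the differentiability requirement concrete, here is a minimal sketch of a generic gradient-based counterfactual search against a stand-in differentiable classifier (all names and hyperparameters are illustrative, not the package's actual generator API):

```julia
using Zygote

# A differentiable stand-in model: logistic regression returning P(y = 1).
w, b = [1.5, -2.0], 0.5
predict_proba(x) = 1 / (1 + exp(-(w' * x + b)))

# Gradient-based search: nudge x towards the target class while
# penalising distance from the original (factual) input x₀.
function counterfactual(x₀; target = 1.0, λ = 0.1, η = 0.1, steps = 100)
    x = copy(x₀)
    for _ in 1:steps
        loss(z) = (predict_proba(z) - target)^2 + λ * sum(abs2, z - x₀)
        x -= η * only(Zygote.gradient(loss, x))
    end
    return x
end

counterfactual([0.0, 0.0])   # only works if ∂predict_proba/∂x exists
```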
I propose the following steps:

1. Implement the `AbstractFittedModel` interface for `MLJ.Supervised` models.
2. (Ideally) build a single `MLJModel <: AbstractFittedModel` type that can handle all (compatible) supervised MLJ models. To this end, we will need a mechanism to differentiate between compatible and incompatible models (see the sketch below).
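A very rough sketch of what such a wrapper plus compatibility check could look like (the `wrap` and `probs` helpers and the check itself are hypothetical, not an existing API):

```julia
using MLJBase

# Stand-in for the package's own abstract model type:
abstract type AbstractFittedModel end

# Hypothetical generic wrapper around a fitted MLJ machine:
struct MLJModel{M<:Machine} <: AbstractFittedModel
    mach::M
end

# A crude compatibility gate: accept only supervised models with
# probabilistic predictions, since gradient-based generators need
# class probabilities rather than hard labels.
function wrap(mach::Machine)
    model = mach.model
    model isa Supervised ||
        error("Only supervised MLJ models are supported.")
    prediction_type(model) == :probabilistic ||
        error("Model must make probabilistic predictions.")
    return MLJModel(mach)
end

# Probability of `positive_class` for a feature matrix X (rows = observations):
probs(M::MLJModel, X::AbstractMatrix, positive_class) =
    pdf.(predict(M.mach, table(X)), positive_class)
```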
This is a challenging task and it is not critical that you succeed at everything. But we would like to aim for the following minimum achievements:
Previous attempts
I have tried this in the past, which might or might not be a good starting point:
- We use `Zygote.jl` for auto-diff. Not sure if all MLJ models can just be "auto-diffed" in that sense, but some early experiments with `EvoTrees` have shown that in principle gradient-based counterfactual generators should be applicable (see here).
- `Zygote.jl` didn't work in this case, and I had to rely on `ForwardDiff` (see here). The problem with trees is that the counterfactual loss function is not smooth, and hence taking gradients just resulted in gradients with all elements equal to zero (at least I think the non-smoothness was the issue here). It would still be preferable to use Zygote if possible.
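For illustration, a toy example (hypothetical, not from the linked experiments) of the zero-gradient problem: forward-mode AD runs without error on a piecewise-constant, tree-like model, but the resulting gradients are zero everywhere:

```julia
using ForwardDiff

# A tree-like, piecewise-constant "model": constant on each region,
# so its derivative is zero almost everywhere.
tree_predict(x) = x[1] > 0.5 ? 0.9 : 0.1

# Smooth counterfactual-style loss wrapped around the non-smooth model:
loss(x) = (tree_predict(x) - 1.0)^2

# ForwardDiff runs fine, but the gradient is identically zero,
# giving a gradient-based generator nothing to work with:
ForwardDiff.gradient(loss, [0.3, 0.7])   # -> [0.0, 0.0]
```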