pat-alt opened 1 year ago
> I appreciate that this is a very ambitious idea
Yeah, I think this is a very noble goal, but indeed a challenging one. Still, it's definitely worth scoping out where the issues lie.
I'm not sure what the particular issue raised above is.
The problem is that MLJ started when Flux was still in relative infancy (no Zygote) and there's a lot of mutation where Zygote just spits the dummy.
When I last played with this, I ran into a rather serious obstacle for probabilistic classifiers. The implementation of `UnivariateFinite` (now at CategoricalDistributions.jl) uses (mutable) dictionaries, which Zygote did not like. I wonder if this is still the case? For example, can I differentiate

```julia
p -> pdf(UnivariateFinite(["x", "y"], p, pool=missing), "x")
```

which is equivalent to `p -> p[1]` for vectors `p`?
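For what it's worth, a minimal version of this check might look as follows (a sketch, assuming recent Zygote.jl and CategoricalDistributions.jl; the `UnivariateFinite` call is left commented out, since the point of the question is that it has historically errored):

```julia
using Zygote
using CategoricalDistributions  # provides UnivariateFinite and pdf

# The plain-vector equivalent: Zygote handles this without trouble.
g(p) = p[1]
grad, = Zygote.gradient(g, [0.3, 0.7])  # grad == [1.0, 0.0]

# The UnivariateFinite version from the question; historically this
# errored inside Zygote due to dictionary mutation:
f(p) = pdf(UnivariateFinite(["x", "y"], p, pool=missing), "x")
# Zygote.gradient(f, [0.3, 0.7])
```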
Thanks for sharing your thoughts on this, Anthony.
We'll be looking at this in the coming weeks/months and I have no doubt we'll run into lots of issues related to mutation. Nonetheless, I think it's worth exploring. I think MLJFlux is a good starting point, since CounterfactualExplanations is currently tailored to Flux. Alternatively, the logistic classifier from MLJLinearModels also seems like a natural first candidate. From there we're most interested in adding support for tree-based models, which will most likely involve a detour via classifier calibration.
If it's alright, I'll keep this open for now and we may come back here with updates.
Motivation and description
Maybe this is a more general topic for MLJ, not only related to Flux. I know that autodiff has been discussed in the past and, with MLJFlux now being developed, I was wondering if this topic has come back into focus.

In an ideal world, it would be possible to differentiate through any `SupervisedModel` and get gradients with respect to parameters or inputs. This would, for example, greatly increase the scope of models we can explain through Counterfactual Explanations (see plans outlined here). MLJFlux seems like a good place to start, since the underlying models are compatible with Zygote. But even here we quickly run into issues: for example, it does not seem possible to differentiate through a `predict` call.

An example:
Both `f` and `g` can be used to return softmax output for `x`. Autodiff only works for `g`, but not for `f`:
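The original snippet is not reproduced here, but a hedged reconstruction of the kind of setup being described might look like this (the builder, toy data, and variable names are assumptions, not the original code):

```julia
using MLJ, MLJFlux, Flux, Zygote

X, y = make_moons(100)  # toy binary classification data
model = MLJFlux.NeuralNetworkClassifier(builder=MLJFlux.MLP(hidden=(8,)))
mach = machine(model, X, y)
fit!(mach, verbosity=0)

chain = fitted_params(mach).chain  # the underlying Flux Chain
x = Float32[0.5, 0.5]              # a single input point

# f: probability of the first class, obtained via MLJ's predict
f(x) = pdf(predict(mach, MLJ.table(reshape(x, 1, :)))[1], levels(y)[1])

# g: softmax output via the raw chain (the fitted chain may already end
# in a softmax finaliser, in which case this is redundant but harmless)
g(x) = softmax(chain(x))

Zygote.gradient(x -> g(x)[1], x)  # works
# Zygote.gradient(f, x)           # fails: predict is not Zygote-differentiable
```

If something like this reproduces the failure, the error should point at whichever non-differentiable operation `predict` hits first (table construction, categorical pooling, and so on).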
A simple workaround for this specific issue is to just use the `Chain` directly to produce the softmax output, but this approach does not generalise to other MLJ models.

I appreciate that this is a very ambitious idea (perhaps previous discussions have concluded that this is simply asking too much), but I would be curious to hear what others think.
Worth mentioning that for the plans mentioned above, I will get some support from a group of CS students soon. So if you have any plans or ongoing work in this space anyway, perhaps there's something we can help with.
Thanks!
Possible Implementation
No response