deel-ai / influenciae

👋 Influenciae is a TensorFlow Toolbox for Influence Functions
https://deel-ai.github.io/influenciae

Extend the use of `tf.function` #6

Closed: fel-thomas closed this issue 1 year ago

fel-thomas commented 2 years ago

Benchmarks show that speed in our use cases improves dramatically when our critical functions are properly vectorized and compiled with `tf.function`. Before adding any new implementation, I propose we go over the code and make sure as much of it as possible runs inside `tf.function`.
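As a minimal sketch of what "vectorized and compiled" can look like (not Influenciae's actual code), assume a hypothetical linear model `w` with a squared loss. Per-sample gradients, which influence functions need, are computed in one `tf.function`-compiled call via `GradientTape.jacobian`, which vectorizes over samples instead of looping over them in Python. The names `w` and `per_sample_grads` are illustrative assumptions.

```python
import tensorflow as tf

# Hypothetical toy model (assumption for illustration): linear regression, squared loss.
w = tf.Variable(tf.ones([3, 1]))

@tf.function
def per_sample_grads(x, y):
    # x: [n, 3] inputs, y: [n, 1] targets.
    with tf.GradientTape() as tape:
        pred = tf.matmul(x, w)                               # [n, 1]
        losses = tf.reduce_sum(tf.square(pred - y), axis=1)  # [n], one loss per sample
    # jacobian() vectorizes over the n losses (pfor by default), so there is
    # no Python loop over samples; the result has shape [n, 3, 1].
    return tape.jacobian(losses, w)

x = tf.constant([[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]])
y = tf.constant([[6.0], [0.0]])
grads = per_sample_grads(x, y)
```

Here the first sample has zero residual, so its gradient is zero, and the second has residual 1, giving a gradient of `2 * 1 * [0, 1, 0]`.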

This will require separating the 'branching code' (e.g., assertions on the dataset) from the computation code.
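One way to sketch that separation, assuming hypothetical helpers `check_dataset`, `batch_sum_of_squares` and `compute_all` (none of them from the library): the Python-side validation runs once, eagerly, on the dataset's `element_spec`, while the compiled function contains only tensor operations and no Python branching on data.

```python
import tensorflow as tf

def check_dataset(dataset):
    # Branching code: plain Python checks on static metadata, run once outside the graph.
    spec = dataset.element_spec
    if not isinstance(spec, tuple) or len(spec) != 2:
        raise ValueError("Expected a dataset of (inputs, targets) pairs.")
    x_spec, _ = spec
    if x_spec.shape.rank != 2:
        raise ValueError("Expected batched 2-D inputs.")

@tf.function
def batch_sum_of_squares(x):
    # Computation code: tensor ops only, so it can be traced and compiled by tf.function.
    return tf.reduce_sum(tf.square(x), axis=1)

def compute_all(dataset):
    # Validate once, then feed each batch to the compiled function.
    check_dataset(dataset)
    return tf.concat([batch_sum_of_squares(x) for x, _ in dataset], axis=0)
```

Keeping the checks outside the decorated function also means they do not need to be rewritten with graph-compatible ops such as `tf.debugging.assert_*`.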

@Agustin-Picard, I think @lucashervier and I can help with that, as we already did it in Xplique. We can propose something, a draft or a V0, and then discuss whether it fits your vision of the library. ;)

Agustin-Picard commented 2 years ago

Sounds good to me!

I think it might get a bit trickier for the CGD IHVP (conjugate gradient descent, inverse Hessian-vector product) implementation, as `tf.autodiff.ForwardAccumulator` is a bit less stable than the typical `tf.GradientTape`, but we should be able to find a workaround ;)
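For context, the two usual ways of computing a Hessian-vector product inside `tf.function` can be sketched as below. This is not Influenciae's CGD IHVP code: it uses a hypothetical quadratic loss `0.5 * w^T A w`, whose Hessian is exactly `A`, so both results can be checked. Treating the reverse-over-reverse (nested `tf.GradientTape`) variant as a fallback for `ForwardAccumulator` is an assumption about what a workaround could look like.

```python
import tensorflow as tf

# Hypothetical quadratic loss (assumption for illustration); its Hessian is A.
A = tf.constant([[2.0, 1.0], [1.0, 3.0]])
w = tf.Variable([0.5, -1.0])

def loss_fn():
    return 0.5 * tf.reduce_sum(w * tf.linalg.matvec(A, w))

@tf.function
def hvp_forward_over_back(v):
    # Forward-mode JVP of the reverse-mode gradient: tangent of grad along v is H v.
    with tf.autodiff.ForwardAccumulator(primals=w, tangents=v) as acc:
        with tf.GradientTape() as tape:
            loss = loss_fn()
        grad = tape.gradient(loss, w)
    return acc.jvp(grad)

@tf.function
def hvp_back_over_back(v):
    # Reverse-over-reverse: the gradient of <grad, v> w.r.t. w is H v (v held constant).
    with tf.GradientTape() as outer:
        with tf.GradientTape() as inner:
            loss = loss_fn()
        grad = inner.gradient(loss, w)
        grad_dot_v = tf.reduce_sum(grad * v)
    return outer.gradient(grad_dot_v, w)

v = tf.constant([1.0, 0.0])
hv_fwd = hvp_forward_over_back(v)
hv_rev = hvp_back_over_back(v)
```

With `v = [1, 0]`, both functions should return the first column of `A`, i.e. `[2, 1]`. A conjugate gradient IHVP solver only needs such an `hvp(v)` callable, so the two variants are interchangeable from its point of view.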