fel-thomas closed this issue 1 year ago
Sounds good to me!
I think it might get a little trickier for the CGD IHVP implementation, since tf.autodiff.ForwardAccumulator is a bit less stable than the usual GradientTape, but we should be able to find a workaround ;)
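For reference, here's a minimal sketch of the forward-over-reverse pattern that ForwardAccumulator enables for Hessian-vector products (the building block of the CGD IHVP); the `hvp` helper and its arguments are illustrative names, not part of the lib:

```python
import tensorflow as tf

def hvp(loss_fn, params, vector):
    # Forward-over-reverse: accumulate the JVP of the gradient, which
    # yields the Hessian-vector product H·v without materializing H.
    with tf.autodiff.ForwardAccumulator(primals=params, tangents=vector) as acc:
        with tf.GradientTape() as tape:
            tape.watch(params)
            loss = loss_fn(params)
        grads = tape.gradient(loss, params)
    return acc.jvp(grads)
```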
The benchmark shows that, in our use cases, speed improves dramatically once the critical functions are properly vectorized and wrapped in tf.function. I propose we go over the code before any new implementation and make sure as much of it as possible runs inside tf.function.
This will require moving / splitting the 'branching code' (e.g. asserts on the dataset...) out of the computation code, as in the sketch below.
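Something along these lines (a rough sketch; `validate_dataset` and `_compute_scores` are hypothetical names, not actual functions from the lib):

```python
import tensorflow as tf

def validate_dataset(dataset):
    # Python-side branching/asserts stay outside the compiled graph.
    assert isinstance(dataset, tf.data.Dataset), "expected a tf.data.Dataset"

@tf.function
def _compute_scores(batch, weights):
    # Pure tensor computation with no Python branching, so tf.function
    # can trace it once and reuse the compiled graph.
    return tf.reduce_sum(batch * weights, axis=-1)

def compute_scores(dataset, weights):
    validate_dataset(dataset)  # branching code, runs eagerly
    return [_compute_scores(batch, weights) for batch in dataset]  # compiled hot path
```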
@Agustin-Picard, I think @lucashervier and I can help with that, as we already did something similar in Xplique. We can propose a draft or a V0, and then discuss whether it fits your vision of the lib ;)