google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Inefficient Jacobian computation for embedding layers #199

Open mohamad-amin opened 9 months ago

mohamad-amin commented 9 months ago

Hello,

When I try to compute the NTK of a model with an embedding layer, I get the following warning:

/usr/local/lib/python3.10/dist-packages/neural_tangents/_src/empirical.py:2215: UserWarning: No Jacobian rule found for gather.
  warnings.warn(f'No Jacobian rule found for {primitive}.')

The NTK computation then fails with out-of-memory (OOM) errors. Here is a reproduction: https://colab.research.google.com/drive/1Z8ClXo85VjNEoKmWYHsS5dNccZ-Xf_JS?usp=sharing
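To see where the `gather` in the warning comes from, here is a minimal pure-JAX sketch. The sizes and the names `embed`, `table`, and `tokens` are hypothetical and are not taken from the linked Colab. An embedding lookup `table[tokens]` lowers to a `gather` primitive. Its Jacobian with respect to the table is almost entirely zeros (one nonzero per output element), so building it densely costs `vocab * dim` entries per output element. This is one plausible reason why a Jacobian-based method without a dedicated `gather` rule runs out of memory.

```python
import jax
import jax.numpy as jnp

VOCAB, DIM = 1000, 8  # hypothetical vocabulary size and embedding width


def embed(table, tokens):
    # Integer-array indexing lowers to a `gather` primitive in the jaxpr.
    return table[tokens]


table = jax.random.normal(jax.random.PRNGKey(0), (VOCAB, DIM))
tokens = jnp.array([3, 17, 3])

# The jaxpr shows the `gather` primitive that the warning refers to.
print(jax.make_jaxpr(embed)(table, tokens))

# Dense Jacobian of the lookup w.r.t. the table: output shape + input shape.
jac = jax.jacrev(embed)(table, tokens)
print(jac.shape)  # (3, 8, 1000, 8)

# Each of the 3 * 8 outputs depends on exactly one table entry.
nonzero_fraction = jnp.count_nonzero(jac) / jac.size
print(nonzero_fraction)
```

The dense Jacobian has 192,000 entries but only 24 nonzeros. With a realistic vocabulary and batch size, this ratio grows quickly.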

romanngg commented 9 months ago

Thanks for pointing this out, and for the repro! Yes, structured derivatives don't have structure annotations or a Jacobian implementation for the scatter/gather primitives, so they are currently very inefficient for this case (I recommend using methods 1/2 instead). I will take a look and see if this can be improved.