qixuanf closed this issue 3 years ago.
Hi @qixuanf, thanks for the message. You've definitely found a bug! I think the issue is that we introduced diagonal_axes into the empirical kernel but didn't fully build out support for it in the batching code. For now, can you try setting diagonal_axes=()? I'll try to submit a fix soon, but it may take slightly longer than usual since NeurIPS is going on at the moment.
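To make the workaround concrete, here is a small NumPy sketch of what diagonalizing over the output axis means for an empirical NTK built from per-example Jacobians. The Jacobians are random stand-ins (not produced by neural_tangents), and the einsum contractions illustrate the math under that assumption; they are not the library's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
N1, N2, K, P = 4, 3, 10, 20  # inputs in x1, inputs in x2, outputs, parameters (toy values)

# Hypothetical per-example Jacobians d f_k(x) / d theta_p, shapes (N, K, P).
J1 = rng.normal(size=(N1, K, P))
J2 = rng.normal(size=(N2, K, P))

# No diagonalization, output axis kept: every pair (k, k') is computed.
full = np.einsum('ikp,jlp->ijkl', J1, J2)  # shape (N1, N2, K, K)

# Diagonal over the output axis: only the k == k' entries are computed.
diag = np.einsum('ikp,jkp->ijk', J1, J2)   # shape (N1, N2, K)

# The diagonal kernel is exactly the diagonal of the full one.
assert np.allclose(diag, np.diagonal(full, axis1=2, axis2=3))
print(full.shape, diag.shape)  # (4, 3, 10, 10) (4, 3, 10)
```

With diagonal_axes=() you get the larger `full`-style tensor and can still recover the diagonal afterwards, which is why it works as a stopgap.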
Thanks! Sure, I won't use this argument for now.
Is there an update on this? Setting diagonal_axes=() is not really a good solution: in many settings it increases evaluation time considerably, and computing the off-diagonal entries is wasted work when they are not needed.
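A rough operation count shows why the extra work grows with the number of outputs. This is a sketch assuming a naive Jacobian contraction; the helper `contraction_madds` is hypothetical, the parameter count p=1000 is an assumption, and neural_tangents may use other evaluation strategies with different costs.

```python
def contraction_madds(n1, n2, k, p, diagonal):
    """Multiply-adds of a naive Jacobian contraction for an empirical NTK.

    diagonal=False: all (k, k') output pairs, i.e. the output axis is kept
    without diagonalizing (as with diagonal_axes=()).
    diagonal=True: only the k == k' pairs.
    """
    pairs = k if diagonal else k * k
    return n1 * n2 * pairs * p


# 50 inputs and 10 outputs as in the shapes discussed here; p=1000 is assumed.
full = contraction_madds(50, 50, 10, 1000, diagonal=False)
diag = contraction_madds(50, 50, 10, 1000, diagonal=True)
print(full, diag, full // diag)  # 250000000 25000000 10
```

Under this model the overhead is a factor of K (the output dimension), so it matters most for wide output layers.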
Thanks for pinging this, it had slipped my mind. Batching should now work correctly with diagonal_axes and trace_axes for empirical kernels as of commit fd1611660c87edcb0c2e50403f691b60d2cc252b. If something still seems off, please don't hesitate to let us know!
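Both arguments change the rank of the kernel's output, so a batching wrapper cannot assume a fixed output shape. The NumPy sketch below shows the three cases for one output axis, again with random stand-in Jacobians. Here "trace" is a plain sum over matching output indices; whether the library additionally normalizes the trace is not assumed.

```python
import numpy as np

rng = np.random.default_rng(1)
J1 = rng.normal(size=(4, 10, 20))  # hypothetical per-example Jacobians (N1, K, P)
J2 = rng.normal(size=(3, 10, 20))  # (N2, K, P)

full = np.einsum('ikp,jlp->ijkl', J1, J2)  # output axis kept as a pair: (N1, N2, K, K)
diag = np.einsum('ikp,jkp->ijk', J1, J2)   # diagonal over the output axis: (N1, N2, K)
traced = np.einsum('ikp,jkp->ij', J1, J2)  # traced over the output axis: (N1, N2)

# Tracing the output pair equals summing its diagonal.
assert np.allclose(traced, np.trace(full, axis1=2, axis2=3))
assert np.allclose(traced, diag.sum(axis=-1))
print(full.shape, diag.shape, traced.shape)  # (4, 3, 10, 10) (4, 3, 10) (4, 3)
```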
I'll close this as fixed; please open a new issue if something is still not working!
As can be seen in the example below, I expected that decorating a kernel function with batch would not change the return shape of the kernel function, i.e. both res_5 and res_10 below should have shape (50, 50, 10). The version of neural_tangents is 0.3.5.