Due to the following issue with matrix inversion in PyTorch, https://github.com/pytorch/pytorch/issues/134334, the TRAK code deadlocks if any of the computed gradients contain NaNs. This is apparently expected behaviour, since the behaviour of torch.inv is undefined for inputs containing NaNs.
The expected result would be a raised error instead of a hang. Adding NaN checks on the computed tensors might be a good idea.
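A minimal sketch of such a check, which raises before the matrix ever reaches the inversion call. It assumes the matrix being inverted is the Gram matrix X^T X of the (projected) gradients, as in TRAK's formulation. `check_finite` and `safe_inverse` are hypothetical helpers, not part of TRAK or PyTorch, and `torch.linalg.inv` stands in for whichever inversion call the TRAK code actually uses:

```python
import torch


def check_finite(name: str, tensor: torch.Tensor) -> torch.Tensor:
    """Raise a ValueError if `tensor` contains NaN or inf values.

    Hypothetical helper (not part of TRAK): turns undefined behaviour
    in the inversion into an explicit error.
    """
    if not torch.isfinite(tensor).all():
        n_nan = int(torch.isnan(tensor).sum())
        n_inf = int(torch.isinf(tensor).sum())
        raise ValueError(
            f"{name} contains {n_nan} NaN and {n_inf} inf values; "
            "refusing to invert (behaviour is undefined for such inputs)"
        )
    return tensor


def safe_inverse(grads: torch.Tensor) -> torch.Tensor:
    """Invert grads^T @ grads, checking for non-finite values first.

    Assumption: `grads` has shape (num_samples, proj_dim) and the matrix
    to invert is its Gram matrix, as in TRAK's (X^T X)^{-1} term.
    """
    check_finite("gradients", grads)
    gram = grads.T @ grads
    # The product can still overflow to inf even if grads are finite.
    check_finite("gram matrix", gram)
    return torch.linalg.inv(gram)


if __name__ == "__main__":
    g = torch.randn(64, 4, dtype=torch.float64)
    print(safe_inverse(g).shape)
    g[0, 0] = float("nan")
    try:
        safe_inverse(g)
    except ValueError as err:
        print("caught:", err)
```

Note that on a GPU the `if not torch.isfinite(tensor).all()` test forces a host/device synchronisation, so the check has a small cost; it is only needed once per matrix, right before inversion.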