Closed souryadey closed 2 years ago
Specific operations causing NaN gradients can be detected by including `torch.autograd.set_detect_anomaly(True)` in your run script.
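A minimal sketch of how anomaly detection surfaces the offending operation. The `0 * log(0)` expression is a deliberately NaN-producing example, not from StatePred itself; with anomaly mode on, autograd raises a `RuntimeError` naming the backward function that returned NaN instead of silently propagating it.

```python
import torch

# Turn on anomaly detection: autograd raises an error at the first
# backward op that produces NaN, naming the forward op responsible.
torch.autograd.set_detect_anomaly(True)

x = torch.zeros(1, requires_grad=True)
y = x * torch.log(x)        # 0 * log(0) = NaN in the forward pass

caught = None
try:
    y.sum().backward()
except RuntimeError as err:  # anomaly mode converts the NaN into an error
    caught = err

torch.autograd.set_detect_anomaly(False)
print("anomaly detected:", caught is not None)
```

Leave this off in production runs: anomaly mode records a traceback for every forward op and slows training noticeably.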
Possible cause:
Some singular values are very close to each other, or very small, during the SVD computation prior to rank truncation, which makes their gradients explode.
Solution:
To avoid large gradient values, initialize `StatePred` with gradient clippers by setting e.g. `clip_grad_norm = 10.` and `clip_grad_value = 5.`. These are features supported by PyTorch, and described in more detail here and here.
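How StatePred applies these settings internally isn't shown here; the sketch below maps the two options onto the standard PyTorch utilities they correspond to (`clip_grad_value_` and `clip_grad_norm_`), using a stand-in linear model. The inflated loss is only there to generate large gradients.

```python
import torch
import torch.nn as nn

# Stand-in model and optimizer (not StatePred's actual internals).
model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

x = torch.randn(8, 4)
loss = (model(x) ** 2).mean() * 1e6   # inflate the loss to get large gradients
opt.zero_grad()
loss.backward()

# clip_grad_value = 5. -> clamp each gradient element to [-5, 5]
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=5.0)
# clip_grad_norm = 10. -> rescale so the total gradient norm is at most 10
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)

opt.step()
```

Clipping by value bounds individual elements, while clipping by norm preserves the gradient's direction; applying value clipping before norm clipping is a common choice, not a requirement.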
Possible cause:
Rank is too high, which makes the condition number and the largest eigenvalue very large.
Solution:
Set the `cond_threshold` argument of `StatePred.train_net()` to lower than the default value, e.g. 10 or 2. This helps in discarding low singular values for the pseudo-inverse computation.
Running `StatePred.train_net()` results in NaN values in the gradient.