Is there a concern about overflow / loss of precision with the change? Maybe we should keep `logdet` for real types and only go to `log(det)` for complex types?
There is also `torch.linalg.slogdet`, which for complex tensors returns the phase (as a unit-modulus complex sign) and the log-magnitude of the determinant.
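For concreteness, a minimal sketch of the overflow scenario (matrix size chosen arbitrarily): for a large enough random matrix the determinant itself overflows double precision, so `log(det)` is useless, while `slogdet` keeps the magnitude in log space and stays finite.

```python
import torch

# Sketch of the overflow concern (size chosen arbitrarily): |det A| exceeds
# the double-precision range, so det() returns inf, while slogdet returns
# a finite log-magnitude.
A = torch.randn((400, 400), dtype=torch.double)

print(torch.det(A))          # inf  (overflowed)
print(torch.det(A).log())    # inf  (log of an overflowed value)

s, r = torch.linalg.slogdet(A)
print(s, r)                  # s = +-1.0, r = finite value of log|det A|
```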
I have checked out `torch.linalg.slogdet`; unfortunately, it has similar precision:
```python
import torch

A = torch.rand((2,2), dtype=torch.cdouble)
# A = tensor([[0.6899+0.2257j, 0.6542+0.8042j],
#             [0.0367+0.9547j, 0.9411+0.8004j]], dtype=torch.complex128)
s, r = torch.linalg.slogdet(A)
# s = tensor(0.9959+0.0909j, dtype=torch.complex128)
# r = tensor(0.1967, dtype=torch.float64)
print(f"{A.det().log():.20}")
# (0.19669927262935532619+0.090991579122977631067j)
print(f"{torch.log(s)+r:.20}")
# (0.19669927262935557599+0.090991579122977617189j)
```
Now I exported the matrix to numpy:
```python
import numpy as np

snp, rnp = np.linalg.slogdet(A.numpy())
print(f"{np.log(snp)+rnp:.20}")
# (0.19669927262935546497+0.090991579122977644944j)
```
which agrees to the same ~15 digits. I scaled this up to a 16 x 16 matrix:
```
A.det().log()            = 128.26763144875246780+2.97543397892812900j
torch.log(s)+r           = 128.26763144875246780+2.97543397892812811j
np.log(np.linalg.det(A)) = 128.26763144875249623+2.97543397892812900j
np.log(snp)+rnp          = 128.26763144875249623+2.97543397892812900j
```
and a 124 x 124 matrix:
```
A.det().log()            = 128.26763144875246780+2.97543397892812900j
torch.log(s)+r           = 128.26763144875246780+2.97543397892812811j
np.log(np.linalg.det(A)) = 128.26763144875249623+2.97543397892812900j
np.log(snp)+rnp          = 128.26763144875249623+2.97543397892812900j
```
with only a minimal (I'd say negligible) precision loss.
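As a quick sanity check on "negligible", comparing the 20-digit real parts printed in the 2x2 example above gives a relative difference at the level of a few double-precision ulps:

```python
# Real parts of A.det().log() and torch.log(s)+r from the 2x2 example above.
a = 0.19669927262935532619
b = 0.19669927262935557599
print(abs(a - b) / abs(a))   # ~1.3e-15, i.e. a handful of ulps in float64
```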
The corresponding interface in `src/NSL/LinAlg/det.tpp` would look like:
```cpp
template <typename Type>
Type logdet(const NSL::Tensor<Type> & t){
    //! \todo If t is not a matrix we would have a stack of determinants: handle this case.
    //! \todo Add logdet as a Tensor member.
    if constexpr (NSL::is_complex<Type>()){
        std::tuple<torch::Tensor,torch::Tensor> slogdet = torch::linalg::slogdet(t);
        return std::log(std::get<0>(slogdet).template item<Type>())
             + std::get<1>(slogdet).template item<Type>();
    } else {
        return torch::logdet(t).template item<Type>();
    }
}
```
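For reference, a hypothetical Python mirror of the same branch logic (the name `logdet` is illustrative, not a library function):

```python
import torch

def logdet(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical Python mirror of the NSL::LinAlg::logdet sketch above.
    if t.is_complex():
        # slogdet path: s is a unit-modulus complex "sign", r = log|det t|
        s, r = torch.linalg.slogdet(t)
        return torch.log(s) + r
    # Real dtypes: torch.logdet works directly.
    return torch.logdet(t)
```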
When I reproduce your 2x2 example in torch, numpy, and with infinite-precision arithmetic in Mathematica, I find agreement to 15 digits after the decimal point for the imaginary part and 16 for the real part. I can't easily assess the larger examples, so I can't swear to their accuracy.
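A sketch of how one could get a high-precision reference without Mathematica, using `mpmath`; note the entries below are the 4-digit rounded values printed above, so this checks only the rounded matrix, and a faithful comparison would need the exact double-precision entries:

```python
import mpmath as mp

mp.mp.dps = 50  # work with 50 significant digits

# 4-digit rounded entries from the 2x2 example above (approximate!)
A = mp.matrix([[mp.mpc(0.6899, 0.2257), mp.mpc(0.6542, 0.8042)],
               [mp.mpc(0.0367, 0.9547), mp.mpc(0.9411, 0.8004)]])

print(mp.log(mp.det(A)))  # high-precision reference for log(det A)
```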
I'm going to merge. We have an open issue #43 anyway, so as long as we keep that open I'm happy.
`torch::logdet` does not work for complex dtypes, so I changed it to call `torch::log(torch::det(t))`.
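A sketch of what that replacement computes (the function name is hypothetical); note that forming the determinant explicitly before taking the log reintroduces the overflow concern raised at the top of the thread:

```python
import torch

def logdet_complex(t: torch.Tensor) -> torch.Tensor:
    # Python equivalent of the torch::log(torch::det(t)) call above.
    # Caveat: det(t) is formed explicitly, so it can overflow/underflow
    # for large matrices before the log is taken.
    return torch.log(torch.det(t))
```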