Open sychen52 opened 2 years ago

In order to write multiple view geometry code that is easy to use in deep learning, I think the input and output tensors should have a batch dimension, and I need batched versions of a few linear algebra functions, such as `torch.bmm`, `torch.svd`, and `torch.diag_embed`. I traced into the NNlib module and noticed it has `batched_mul` and `batched_transpose`/`batched_adjoint`, but no batched `svd` or `diagm`.

Is this the correct place to add these batched functions? Or is there a more elegant solution than writing all these batched functions and adding them to NNlib?
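For concreteness, `torch.diag_embed` just places each vector on the diagonal of one matrix in a batch; a naive Julia equivalent (the name `batched_diagm` is invented here, not an existing API) might look like:

```julia
# Hypothetical batched_diagm, a naive torch.diag_embed analogue:
# column k of d becomes the diagonal of the k-th n×n slice of D.
function batched_diagm(d::AbstractMatrix{T}) where {T}
    n, b = size(d)
    D = zeros(T, n, n, b)
    for k in 1:b, i in 1:n
        D[i, i, k] = d[i, k]
    end
    return D
end
```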
@Roger-luo was going to collect many of these in https://github.com/Roger-luo/BatchedRoutines.jl but that was a while ago. Otherwise here sounds OK to me.
`batched_mul` got pretty complicated, as it wanted to accept every `PermutedDimsArray` that BLAS could handle directly, which was possibly a mistake. If `batched_svd` is less ambitious it could be much simpler. But what would it return?
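One possible answer, sketched here assuming the same dense `(rows, cols, batch)` layout that `batched_mul` uses (the name and layout are guesses, not an agreed API), is to take the thin SVD of each slice and stack the factors:

```julia
using LinearAlgebra

# Naive CPU sketch: thin SVD of each slice A[:, :, i], with the
# factors stacked along a trailing batch dimension.
function batched_svd(A::AbstractArray{T,3}) where {T}
    m, n, b = size(A)
    k = min(m, n)
    U  = similar(A, m, k, b)
    S  = similar(A, real(T), k, b)
    Vt = similar(A, k, n, b)
    for i in 1:b
        F = svd(A[:, :, i])
        U[:, :, i]  = F.U
        S[:, i]     = F.S
        Vt[:, :, i] = F.Vt
    end
    return U, S, Vt
end
```

Whether to return three stacked arrays like this or a vector of `SVD` factorizations is exactly the open design question, though.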
Are there batched CUDA versions of these functions?
There are a few batched linear algebra functions in MAGMA, which we tried to wrap in Julia but ran into issues with BinaryBuilder. See Magma.jl.
Am I understanding this correctly? In order to have `batched_svd` in NNlib: 1) we need a batched CPU version using LAPACK (maybe put in BatchedRoutines.jl), 2) we also need a batched CUDA version using MAGMA (put in Magma.jl), and 3) we then unify the API in NNlib by calling the CPU and GPU versions underneath.
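If that reading is right, step 3 could be plain multiple dispatch; a hypothetical sketch (the generic method below just loops LAPACK over slices, and the `CuArray` method is only a comment):

```julia
using LinearAlgebra

# Hypothetical unified entry point in NNlib: the generic method is the
# CPU path, and a GPU package extends the same function for its arrays.
batched_svd(A::AbstractArray{T,3}) where {T} =
    [svd(A[:, :, i]) for i in 1:size(A, 3)]  # one LAPACK call per slice

# e.g. in Magma.jl or a CUDA extension:
# batched_svd(A::CuArray{T,3}) where {T} = ...  # MAGMA / cuSOLVER batched kernel
```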
Judging by https://github.com/pytorch/pytorch/blob/5dbec7c07c5eedd748fd56359c2d1b980dfa1037/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp, the MAGMA dependency may not be required for GPU support; PyTorch implements these routines there with cuSOLVER.
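If that's the case, a slice-wise fallback may already work on GPU without any extra binary dependency, since (as far as I know) CUDA.jl routes `svd` on a `CuMatrix` through cuSOLVER:

```julia
using CUDA, LinearAlgebra

# Per-slice SVD on the GPU: slow compared to a truly batched kernel,
# but each svd call here should hit cuSOLVER via CUDA.jl.
A = CUDA.rand(Float32, 8, 8, 16)
factors = [svd(A[:, :, i]) for i in 1:size(A, 3)]
```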
@ToucheSir, is there any example code you could point to as a starting point for writing `batched_svd`, or how would you recommend approaching it?
I'm almost certainly the one with the least linear algebra knowledge on this thread and thus not the one you want to ask such questions :)