FluxML / NNlib.jl

Neural Network primitives with multiple backends

More batched functions such as batched_svd, batched_diagm #401

Open sychen52 opened 2 years ago

sychen52 commented 2 years ago

To write multiple-view geometry code that is easy to use in deep learning, I think the input and output tensors should have a batch dimension. I need batched versions of a few linear algebra functions, such as torch.bmm, torch.svd, and torch.diag_embed. I traced into the NNlib module and noticed it has batched_mul and batched_transpose/batched_adjoint, but not svd or diagm.
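For concreteness, here is a minimal sketch of what a batched diagm could look like, roughly the analogue of torch.diag_embed. The name `batched_diagm` and the convention that the batch is the last dimension are assumptions for illustration; this is not an existing NNlib function.

```julia
using LinearAlgebra

# Hypothetical helper, not part of NNlib: each column x[:, k] becomes the
# diagonal of the k-th n × n slice of an n × n × batch array.
function batched_diagm(x::AbstractMatrix{T}) where {T}
    n, b = size(x)
    out = zeros(T, n, n, b)
    for k in 1:b, i in 1:n
        out[i, i, k] = x[i, k]
    end
    return out
end

x = [1.0 3.0; 2.0 4.0]   # two batch entries, each a length-2 vector
D = batched_diagm(x)     # 2 × 2 × 2 array of diagonal matrices
```

Each slice `D[:, :, k]` should match `diagm(x[:, k])` from LinearAlgebra, so the loop can be checked against the unbatched function.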

Is this the correct place to add these batched version functions?

sychen52 commented 2 years ago

Or is there a more elegant solution than writing all these batched functions and adding them in NNlib?

mcabbott commented 2 years ago

@Roger-luo was going to collect many of these in https://github.com/Roger-luo/BatchedRoutines.jl but that was a while ago. Otherwise here sounds OK to me.

batched_mul got pretty complicated because it tries to accept any PermutedDimsArray that BLAS can handle directly, which was possibly a mistake. If batched_svd is less ambitious, it could be much simpler. But what would it return?
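One possible answer to the return-type question, sketched under assumptions: return three dense arrays (U, S, V) with the batch as the last dimension, instead of a vector of `SVD` objects. The name `batched_svd` and this layout are hypothetical, and the loop below is a plain CPU version built on `LinearAlgebra.svd`, with no attempt at the generality of batched_mul.

```julia
using LinearAlgebra

# Hypothetical sketch, not an NNlib function: thin SVD of every slice A[:, :, k].
function batched_svd(A::AbstractArray{T,3}) where {T<:AbstractFloat}
    m, n, b = size(A)
    r = min(m, n)
    U = similar(A, (m, r, b))
    S = similar(A, (r, b))
    V = similar(A, (n, r, b))
    for k in 1:b
        F = svd(A[:, :, k])   # thin SVD of one slice (LAPACK on the CPU)
        U[:, :, k] .= F.U
        S[:, k] .= F.S
        V[:, :, k] .= F.V
    end
    return U, S, V
end

A = randn(4, 3, 5)
U, S, V = batched_svd(A)

# Rebuild one slice to check the factorization.
A1 = U[:, :, 1] * Diagonal(S[:, 1]) * V[:, :, 1]'
```

Returning plain arrays keeps the outputs in the same shape convention that batched_mul consumes, at the cost of losing the `SVD` wrapper type.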

Are there batched CUDA versions of these functions?

Roger-luo commented 2 years ago

There are a few batched linear algebra functions in MAGMA, which we tried to wrap in Julia but ran into issues with BinaryBuilder (BB). See Magma.jl

sychen52 commented 2 years ago

Am I understanding this correctly? To have batched_svd in NNlib, we would need (1) a batched CPU version using LAPACK (maybe in BatchedRoutines.jl), (2) a batched CUDA version using MAGMA (in Magma.jl), and (3) a unified function API in NNlib that calls the CPU or GPU version underneath.
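The three-part layout described here maps naturally onto Julia's multiple dispatch. The following is only a sketch of that idea: `batched_svdvals` and `_batched_svdvals` are hypothetical names, only the CPU method is implemented, and the GPU method is left as a comment because its backing library is exactly what is undecided in this thread.

```julia
using LinearAlgebra

# Public entry point (hypothetical name): one API for every array type.
batched_svdvals(A::AbstractArray{<:Real,3}) = _batched_svdvals(A)

# CPU backend: loop over slices and call LAPACK through LinearAlgebra.svdvals.
function _batched_svdvals(A::Array{T,3}) where {T<:AbstractFloat}
    m, n, b = size(A)
    S = Matrix{T}(undef, min(m, n), b)
    for k in 1:b
        S[:, k] = svdvals(A[:, :, k])
    end
    return S
end

# Fallback for array types without a dedicated backend: copy to a dense Array.
_batched_svdvals(A::AbstractArray{<:Real,3}) = _batched_svdvals(Array(float.(A)))

# A GPU package would add its own `_batched_svdvals` method for its array
# type; the entry point above would not need to change.

A = cat([2.0 0.0; 0.0 1.0], [3.0 0.0; 0.0 4.0]; dims=3)
S = batched_svdvals(A)                          # CPU method
Sv = batched_svdvals(view(A, :, :, 2:2))        # goes through the fallback
```

The design choice is that users only ever call the entry point, so where the CPU and GPU implementations live (NNlib, BatchedRoutines.jl, or a GPU package) stays an internal detail.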

ToucheSir commented 2 years ago

Judging by https://github.com/pytorch/pytorch/blob/5dbec7c07c5eedd748fd56359c2d1b980dfa1037/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp, the magma dep may not be required for GPU support.

nikopj commented 1 year ago

@ToucheSir is there any example code you could point to for starting on the path of writing batched_svd / how would you recommend approaching it?

ToucheSir commented 1 year ago

I'm almost certainly the one with the least linear algebra knowledge on this thread and thus not the one you want to ask such questions :)