kokkos / stdBLAS

symmetric_matrix_rank_k_update != dsyrk #272

Closed by prlw1 1 month ago

prlw1 commented 1 month ago

As mentioned at the end of #261, I tried the example posted by @yigiter in https://github.com/kokkos/stdBLAS/issues/261#issuecomment-1873645288, which led me to wonder what symmetric_matrix_rank_k_update is computing.

Given its name, I assume that symmetric_matrix_rank_k_update maps to the BLAS function dsyrk (but with a single scalar constant alpha, rather than BLAS's alpha and beta) described as:

DSYRK performs one of the symmetric rank k operations

C := alpha*A*A**T + beta*C,

or

C := alpha*A**T*A + beta*C,

where alpha and beta are scalars, C is an n by n symmetric matrix and A is an n by k matrix in the first case and a k by n matrix in the second case.
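To make the first form concrete, here is a minimal sketch of what I understand the row-major, upper-triangle, no-transpose case to compute. The function name `syrk_upper_notrans` and its signature are made up for illustration; this is not code from stdBLAS or from any BLAS library, just a plain-loop reading of the description above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference routine (illustration only, not a library API).
// Row-major storage, upper triangle, no transpose:
//   C(i,j) := alpha * sum_p A(i,p) * A(j,p) + beta * C(i,j)   for j >= i.
// A is n x k and C is n x n. Entries of C below the diagonal are never
// read or written.
void syrk_upper_notrans(std::size_t n, std::size_t k, double alpha,
                        const std::vector<double>& A, double beta,
                        std::vector<double>& C)
{
    assert(A.size() == n * k && C.size() == n * n);
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = i; j < n; ++j) {
            // (A * A^T)(i,j) is the dot product of rows i and j of A.
            double dot = 0.0;
            for (std::size_t p = 0; p < k; ++p) {
                dot += A[i * k + p] * A[j * k + p];
            }
            C[i * n + j] = alpha * dot + beta * C[i * n + j];
        }
    }
}
```

Called with n = 3, k = 2, alpha = beta = 1 on the A and C used in the example below, this should update only the six upper-triangle entries of C and leave the three strictly lower entries as they were.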

I reused @yigiter's example as follows:

#define MDSPAN_USE_PAREN_OPERATOR 1

#include <type_traits>
#include <mdspan/mdspan.hpp>
#include <experimental/linalg>
#include <fmt/ranges.h>
extern "C" {
        #include <cblas.h>
}

using namespace MDSPAN_IMPL_STANDARD_NAMESPACE;
using namespace MDSPAN_IMPL_STANDARD_NAMESPACE :: MDSPAN_IMPL_PROPOSED_NAMESPACE;
using namespace MDSPAN_IMPL_STANDARD_NAMESPACE :: MDSPAN_IMPL_PROPOSED_NAMESPACE :: linalg;

void print(auto const& M)
{
        for (int i = 0; i < M.extent(0); ++i) {
                for (int j = 0; j < M.extent(1); ++j) {
                        fmt::print("{:3} ", M(i,j));
                }
                fmt::print("\n");
        }
}

int main()
{
        double lgAvec[6]={1,1,1,2,2,2};
        double lgCvec[9]={1,2,3,2,4,5,3,5,6};
        mdspan lgA(lgAvec,3,2);
        mdspan lgC(lgCvec,3,3);
        fmt::print("============== linalg ==============\n");
        fmt::print("============== before ==============\n");
        fmt::print("A=\n");print(lgA);
        fmt::print("C=\n");print(lgC);
        symmetric_matrix_rank_k_update(lgA, lgC, upper_triangle);
        fmt::print("============== after  ==============\n");
        fmt::print("A=\n");print(lgA);
        fmt::print("C=\n");print(lgC);

        fmt::print("============== openblas ============\n");
        double blAvec[6]={1,1,1,2,2,2};
        double blCvec[9]={1,2,3,2,4,5,3,5,6};
        mdspan blA(blAvec,3,2);
        mdspan blC(blCvec,3,3);
        fmt::print("============== before ==============\n");
        fmt::print("A=\n");print(blA);
        fmt::print("C=\n");print(blC);
        cblas_dsyrk(CblasRowMajor, CblasUpper, CblasNoTrans, 3, 2, 1.0, blAvec, 2, 1.0, blCvec, 3);
        fmt::print("============== after  ==============\n");
        fmt::print("A=\n");print(blA);
        fmt::print("C=\n");print(blC);
}

which generates the output:

============== linalg ==============
============== before ==============
A=
  1   1 
  1   2 
  2   2 
C=
  1   2   3 
  2   4   5 
  3   5   6 
============== after  ==============
A=
  1   1 
  1   2 
  2   2 
C=
  4   6   3 
  6  12   5 
  3   5   6 
============== openblas ============
============== before ==============
A=
  1   1 
  1   2 
  2   2 
C=
  1   2   3 
  2   4   5 
  3   5   6 
============== after  ==============
A=
  1   1 
  1   2 
  2   2 
C=
  3   5   7 
  2   9  11 
  3   5  14 

I expect the answer

    3    5    7
    5    9   11
    7   11   14

as per the blas function (upper triangle).

Given any matrix A, A A^T is always symmetric, so the symmetry constraint needs to be on the additive matrix C in order to skip one triangle of computations. The comment in https://github.com/kokkos/stdBLAS/blob/06e90a58b67c5adefff0f06904c8f8bc3371815b/include/experimental/__p1673_bits/blas3_matrix_rank_k_update.hpp#L152 seems to suggest that A is the matrix which needs to be symmetric?

I am puzzled as to what symmetric_matrix_rank_k_update is meant to be computing.
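The symmetry argument can be checked directly. The sketch below forms the full product A A^T for a rectangular (and therefore certainly not symmetric) A and tests whether the result equals its own transpose. Both helper names, `gram_full` and `is_symmetric`, are hypothetical and exist only for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: full row-major product G = A * A^T, where A is n x k.
std::vector<double> gram_full(std::size_t n, std::size_t k,
                              const std::vector<double>& A)
{
    assert(A.size() == n * k);
    std::vector<double> G(n * n, 0.0);
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = 0; j < n; ++j) {
            for (std::size_t p = 0; p < k; ++p) {
                G[i * n + j] += A[i * k + p] * A[j * k + p];
            }
        }
    }
    return G;
}

// Hypothetical helper: true when the n x n row-major matrix M equals its
// transpose. Exact comparison is enough here because G(i,j) and G(j,i)
// are sums of the same products accumulated in the same order.
bool is_symmetric(std::size_t n, const std::vector<double>& M)
{
    assert(M.size() == n * n);
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = i + 1; j < n; ++j) {
            if (M[i * n + j] != M[j * n + i]) {
                return false;
            }
        }
    }
    return true;
}
```

Since G(i,j) is the dot product of rows i and j of A, swapping i and j gives the same value, which is why only C, not A, has to be symmetric for the triangle shortcut to be valid.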

mhoemmen commented 1 month ago

Hi! Last week was a short week, so I didn't have a chance to take a look. I'll try to do that as soon as I can. Thanks!

mhoemmen commented 1 month ago

PR #275 looks like it fixes this -- thanks! :-)

prlw1 commented 4 weeks ago

Given the state of affairs rasolca@ discovered, there is just one doubt left: is the fact that dsyrk takes two scalars, but symmetric_matrix_rank_k_update takes one scalar (no beta) intentional (because one can easily premultiply C), or an oversight?
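For reference, this is roughly what I mean by premultiplying C. It is a sketch under assumptions: `rank_k_update_upper` is a made-up stand-in for a one-scalar update that accumulates into C (I am assuming that behaviour, not quoting stdBLAS), and `scale_upper` is a hypothetical helper that applies beta to the same triangle first.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: C(i,j) *= beta for j >= i (row-major, n x n).
void scale_upper(std::size_t n, double beta, std::vector<double>& C)
{
    assert(C.size() == n * n);
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = i; j < n; ++j) {
            C[i * n + j] *= beta;
        }
    }
}

// Hypothetical stand-in for a one-scalar update (assumed to accumulate):
//   C(i,j) += alpha * sum_p A(i,p) * A(j,p)   for j >= i,
// with A an n x k row-major matrix.
void rank_k_update_upper(std::size_t n, std::size_t k, double alpha,
                         const std::vector<double>& A,
                         std::vector<double>& C)
{
    assert(A.size() == n * k && C.size() == n * n);
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = i; j < n; ++j) {
            double dot = 0.0;
            for (std::size_t p = 0; p < k; ++p) {
                dot += A[i * k + p] * A[j * k + p];
            }
            C[i * n + j] += alpha * dot;
        }
    }
}
```

Calling `scale_upper(n, beta, C)` followed by `rank_k_update_upper(n, k, alpha, A, C)` gives alpha*A*A^T + beta*C on the upper triangle, at the cost of an extra pass over C compared with a single call that takes both scalars.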

mhoemmen commented 4 weeks ago

@prlw1 It's probably an oversight, based on _DSYR not having a beta parameter. Thanks for pointing this out! It's pretty easy to add that overload to {symmetric,hermitian}_matrix_rank_[2]k_update without breaking backwards compatibility (as one could always constrain that scalar template parameter for beta not to be mdspan).

mhoemmen commented 2 weeks ago

@prlw1 I've filed https://github.com/ORNL/cpp-proposals-pub/issues/468 to track the {symmetric,hermitian}_matrix_rank_[2]k_update issue. I actually think this needs to be fixed before C++26, as I explain at that link. Please feel free to comment there if you wish; thanks!