ORNL / cpp-proposals-pub

Collaborating on papers for the ISO C++ committee - public repo

std::linalg: Fix {symmetric,hermitian}_matrix_rank_[2]k_update #468

Closed mhoemmen closed 3 months ago

mhoemmen commented 3 months ago

Please see the comment here: https://github.com/kokkos/stdBLAS/issues/272#issuecomment-2248273146

DSYRK computes $C := \beta C + \alpha A A^T$, but symmetric_matrix_rank_k_update only computes $C := C + A A^T$ or $C := C + \alpha A A^T$. There's currently no way to express the $\beta$ scaling factor except for the $\beta = 1$ special case. This is a general issue for all the {symmetric,hermitian}_matrix_rank_[2]k_update functions.
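
To make the gap concrete, here is a minimal sketch in plain C++ loops rather than std::linalg itself. The `Matrix` type and the names `syrk_like` and `rank_k_update_current` are hypothetical, introduced only for illustration; only the upper triangle is touched, as with `upper_triangle`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense row-major matrix, used only for this illustration.
struct Matrix {
    std::size_t rows;
    std::size_t cols;
    std::vector<double> data;

    Matrix(std::size_t r, std::size_t c, double init = 0.0)
        : rows(r), cols(c), data(r * c, init) {}

    double& operator()(std::size_t i, std::size_t j) { return data[i * cols + j]; }
    double operator()(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// DSYRK-style update of the upper triangle: C := beta * C + alpha * A * A^T.
// Unlike the reference BLAS, this sketch does not special-case beta == 0.
inline void syrk_like(double alpha, const Matrix& A, double beta, Matrix& C) {
    assert(C.rows == C.cols && C.rows == A.rows);
    for (std::size_t i = 0; i < C.rows; ++i) {
        for (std::size_t j = i; j < C.cols; ++j) {
            double sum = 0.0;
            for (std::size_t k = 0; k < A.cols; ++k) {
                sum += A(i, k) * A(j, k);
            }
            C(i, j) = beta * C(i, j) + alpha * sum;
        }
    }
}

// What the current symmetric_matrix_rank_k_update semantics can express:
// C := C + alpha * A * A^T, i.e. beta is fixed at 1.
inline void rank_k_update_current(double alpha, const Matrix& A, Matrix& C) {
    syrk_like(alpha, A, 1.0, C);
}
```

Any call to `syrk_like` with `beta` other than 1 has no counterpart in `rank_k_update_current`, which is the missing functionality described above.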

Functions like matrix_vector_product have "overwriting" ($y := A x$) and "updating" ($z := y + A x$) overloads. The updating overload handles the general $\alpha$ and $\beta$ case via matrix_vector_product(A, scaled(alpha, x), scaled(beta, y), y); the aliasing is explicitly permitted ("Remarks: z may alias y"). The distinction between "overwriting" and "updating" overloads lets std::linalg treat the $\beta = 0$ case differently from the BLAS. This is a deliberate design choice: it permits Inf and NaN propagation and simplifies implementations, which no longer need to branch on $\beta = 0$.
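
The overwriting/updating distinction can be sketched with plain loops. This is an illustration under stated assumptions, not the std::linalg implementation: `matvec_overwrite`, `matvec_update`, `scaled_copy`, and `any_nan` are hypothetical names, and `scaled_copy` copies eagerly where std::linalg's `scaled` returns a lazy view.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;  // also used for an m x n row-major matrix

// Overwriting form: y := A x. The old contents of y are never read.
inline void matvec_overwrite(const Vec& A, const Vec& x, Vec& y) {
    const std::size_t m = y.size();
    const std::size_t n = x.size();
    assert(A.size() == m * n);
    for (std::size_t i = 0; i < m; ++i) {
        double sum = 0.0;
        for (std::size_t j = 0; j < n; ++j) {
            sum += A[i * n + j] * x[j];
        }
        y[i] = sum;
    }
}

// Updating form: z := y + A x. z may alias y, because y[i] is read
// before z[i] is written.
inline void matvec_update(const Vec& A, const Vec& x, const Vec& y, Vec& z) {
    const std::size_t m = y.size();
    const std::size_t n = x.size();
    assert(A.size() == m * n && z.size() == m);
    for (std::size_t i = 0; i < m; ++i) {
        double sum = 0.0;
        for (std::size_t j = 0; j < n; ++j) {
            sum += A[i * n + j] * x[j];
        }
        z[i] = y[i] + sum;
    }
}

// Eager stand-in for scaled(s, v): returns a scaled copy of v.
inline Vec scaled_copy(double s, const Vec& v) {
    Vec out(v);
    for (double& e : out) {
        e *= s;
    }
    return out;
}

// Reports whether any element is NaN, to observe propagation.
inline bool any_nan(const Vec& v) {
    for (double e : v) {
        if (std::isnan(e)) {
            return true;
        }
    }
    return false;
}
```

In this sketch, `matvec_update(A, scaled_copy(alpha, x), scaled_copy(beta, y), y)` plays the role of the general $\alpha$, $\beta$ call. With `beta` equal to 0 and a NaN in `y`, the product `0 * NaN` is NaN and so propagates to the result, whereas `matvec_overwrite` ignores the old `y` entirely.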

The issue with functions like symmetric_matrix_rank_k_update is that they have the same syntax as an "overwriting" overload, but "updating" semantics. This means that we can't fix the problem just by introducing an "updating" overload for each function. (The fact that these functions have "update" in their name is not relevant, because that naming choice is historical. The BLAS Standard calls its corresponding routines "{Symmetric, Hermitian} rank {one, two} update.")

We propose to make these functions work just like matrix_vector_product or matrix_product, by adding "updating" overloads, and changing the existing overloads to be "overwriting." For example, symmetric_matrix_rank_k_update(A, C, upper_triangle) will perform $C := A A^T$ instead of $C := C + A A^T$. This is a breaking change; it will change the behavior of existing functions, without changing their syntax. Thus, we must do it before finalization of C++26.
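
A rough sketch of the two proposed behaviors, again in plain loops. The names `rank_k_overwrite`, `rank_k_update`, and `scaled_copy` are hypothetical, and the shape of the updating form (a separate input matrix `E` that the output `C` may alias) is an assumption made by analogy with matrix_vector_product; the issue text does not spell out the signature.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Mat = std::vector<double>;  // row-major storage

// Overwriting form: upper(C) := alpha * A * A^T, with A of size n x k.
// The old contents of C are never read.
inline void rank_k_overwrite(double alpha, const Mat& A, std::size_t n,
                             std::size_t k, Mat& C) {
    assert(A.size() == n * k && C.size() == n * n);
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = i; j < n; ++j) {
            double sum = 0.0;
            for (std::size_t p = 0; p < k; ++p) {
                sum += A[i * k + p] * A[j * k + p];
            }
            C[i * n + j] = alpha * sum;
        }
    }
}

// Updating form (assumed shape): upper(C) := upper(E) + alpha * A * A^T.
// C may alias E, because each E entry is read before the matching C entry
// is written.
inline void rank_k_update(double alpha, const Mat& A, std::size_t n,
                          std::size_t k, const Mat& E, Mat& C) {
    assert(A.size() == n * k && E.size() == n * n && C.size() == n * n);
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = i; j < n; ++j) {
            double sum = 0.0;
            for (std::size_t p = 0; p < k; ++p) {
                sum += A[i * k + p] * A[j * k + p];
            }
            C[i * n + j] = E[i * n + j] + alpha * sum;
        }
    }
}

// Eager stand-in for scaled(s, M): returns a scaled copy of M.
inline Mat scaled_copy(double s, const Mat& M) {
    Mat out(M);
    for (double& e : out) {
        e *= s;
    }
    return out;
}
```

Under these assumptions, the DSYRK case $C := \beta C + \alpha A A^T$ would be written as `rank_k_update(alpha, A, n, k, scaled_copy(beta, C), C)`, mirroring the matrix_vector_product idiom.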

What about functions such as symmetric_matrix_rank_1_update? These are fine, because the corresponding BLAS 2 functions like DSYR don't take $\beta$. DSYR performs $A := \alpha x x^T + A$. This is probably the reason for the current symmetric_matrix_rank_k_update design: pure analogy with symmetric_matrix_rank_1_update, without considering that the BLAS 3 versions of these functions take $\beta$ and the BLAS 2 versions do not.

mhoemmen commented 3 months ago

Fix submitted as P3371R0. Thanks! :-)