kokkos / stdBLAS

Reference Implementation for stdBLAS

Optimizing nested conjugated / transposed / scaled expressions #203

Open mhoemmen opened 2 years ago

mhoemmen commented 2 years ago

Discussion on PR #197 and elsewhere (e.g., with @youyu3 ) shows that it's tricky to optimize expressions like transposed(conjugated(scaled(alpha, A))) that result from calling, e.g., matrix_product. "Optimize" here means "deduce that we can call an optimized BLAS routine." For layout_left A, the Fortran BLAS can handle this case directly, by setting TRANSA='C' and ALPHA=alpha (strictly, ALPHA=conj(alpha) for complex alpha, since the conjugation wraps the scaling factor as well).

It occurred to me that a "recursive" design could make this easier. I put "recursive" in quotes because it's based on function overloads; the calls to the function with the same name aren't actually recursive, because their arguments' types change on each nested call.

Here's some pseudocode:

enum class ETrans { N, T, H, C };

template<std::semiregular Scalar, ETrans Trans>
struct Extracted {
    static constexpr ETrans trans = Trans;
    std::optional<Scalar> scalar;
};

template<std::semiregular Scalar, ETrans Trans>
auto toggle_transpose(Extracted<Scalar, Trans> e)
{
    if constexpr (Trans == ETrans::N) {
        return Extracted<Scalar, ETrans::T>{e.scalar};
    } else if constexpr (Trans == ETrans::T) {
        return Extracted<Scalar, ETrans::N>{e.scalar};
    } else if constexpr (Trans == ETrans::H) {
        return Extracted<Scalar, ETrans::C>{e.scalar};
    } else { // ETrans::C
        return Extracted<Scalar, ETrans::H>{e.scalar};
    }
}

template<std::semiregular Scalar, ETrans Trans>
auto toggle_conjugate(Extracted<Scalar, Trans> e)
{
    if constexpr (Trans == ETrans::N) {
        return Extracted<Scalar, ETrans::C>{e.scalar};
    } else if constexpr (Trans == ETrans::T) {
        return Extracted<Scalar, ETrans::H>{e.scalar};
    } else if constexpr (Trans == ETrans::H) {
        return Extracted<Scalar, ETrans::T>{e.scalar};
    } else { // ETrans::C
        return Extracted<Scalar, ETrans::N>{e.scalar};
    }
}

template<std::semiregular InputScalar, std::semiregular Scalar, ETrans Trans>
auto add_or_replace_scalar(InputScalar s, Extracted<Scalar, Trans> e)
{
    return Extracted<InputScalar, Trans>{s}; // discard current scalar in e
}

// omitting constraints on template parameters for brevity
template<class in_matrix_1_t, class Extracted1,
    class in_matrix_2_t, class Extracted2,
    class out_matrix_t,
    class in_matrix_1_original_t,
    class in_matrix_2_original_t>
void matrix_product_impl(
    in_matrix_1_t A, Extracted1 A_data,
    in_matrix_2_t B, Extracted2 B_data,
    out_matrix_t C,
    in_matrix_1_original_t A_original,
    in_matrix_2_original_t B_original)
{
    if constexpr (/* It's obvious we can't call the BLAS */) {
        // Early exit from the "recursion" avoids penalizing the generic case with higher compile times.
        matrix_product_fallback(A_original, B_original, C);
    }
    else if constexpr (/* A's outer layout is layout_transpose */) {
        matrix_product_impl(strip_nested_mapping(A), toggle_transpose(A_data),
            B, B_data, C, A_original, B_original);
    }
    else if constexpr (/* A's outer accessor is accessor_scaled */) {
        if(A_data.scalar.has_value()) {
            // ... check at compile time that it makes sense to multiply the two scaling factors, else fall back ...
            matrix_product_impl(strip_nested_accessor(A),
                add_or_replace_scalar(A.accessor().scaling_factor() * A_data.scalar.value(), A_data),
                B, B_data, C, A_original, B_original);
        } else {
            matrix_product_impl(strip_nested_accessor(A), add_or_replace_scalar(A.accessor().scaling_factor(), A_data),
                B, B_data, C, A_original, B_original);
        }
    }
    else if constexpr (/* A's outer accessor is accessor_conjugate */) {
        matrix_product_impl(strip_nested_accessor(A), toggle_conjugate(A_data),
            B, B_data, C, A_original, B_original);
    }
    // ... repeat the above pattern for B ...
    else if constexpr (/* all the types are BLAS friendly */) {
        // ETrans is a template parameter of Extracted so we can check most BLAS compatibility at compile time.
        // Extracted1 and Extracted2 are template parameters so that we don't force Scalar type conversion.
        // Some mixed-precision BLAS implementations (e.g., cublasGemmEx) permit the Scalar type
        // to differ from the matrices' value types.

        if (/* any run-time decision whether we can call BLAS, e.g., layout_stride run-time strides */) {
            // ... call the BLAS using scalar and transpose from both Extracted structs ...
        } else {
            matrix_product_fallback(A_original, B_original, C);
        }
    }
    else {
        matrix_product_fallback(A_original, B_original, C);
    }
}

void matrix_product(in_matrix_1_t A, in_matrix_2_t B, out_matrix_t C)
{
    // "Recursive" calls may change the Extracted Scalar type.
    matrix_product_impl(A, Extracted<typename in_matrix_1_t::value_type, ETrans::N>{},
        B, Extracted<typename in_matrix_2_t::value_type, ETrans::N>{},
        C, A, B);
}

For a C++14-compatible implementation, one could use function overloads or class template partial specialization instead of if constexpr.

Here are some issues with the above approach.

  1. It has to construct mdspan once per "recursion" level.
  2. It increases the function call depth, which may interfere with inlining in the fall-back case.

We can fix at least (2) by applying the recursive approach to each pair (A, A_data) and (B, B_data). This will bound the function call depth for the fall-back case.

template<class InMatrix, class ExtractedType>
auto extract(std::tuple<InMatrix, ExtractedType>);

Regarding (1), we can mitigate this by limiting the "recursion" depth for cases that the BLAS obviously can't handle. Also, taking the mdspan by value lets us move-construct the pointer, layout, and accessor at each level, so we can reduce cost for the (admittedly unusual) case where any of these are expensive to construct.

fnrizzi commented 2 years ago

@mhoemmen wow you drafted this really fast!

youyu3 commented 2 years ago

Great pseudocode! It would be great to have something that works for arbitrarily nested conjugated/transposed/scaled (if we don't limit the recursion depth).

One quick question: don't we need to check if it's layout_left or layout_right, in addition to layout_transpose, to figure out if a transpose is needed when calling BLAS?

Personally, I'd prefer to resolve the effective trans and alpha (maybe some recursive logic on the layout/accessor) and then have a single call to the impl function :).

mhoemmen commented 2 years ago

@youyu3 wrote:

... don't we need to check if it's layout_left or layout_right, in addition to layout_transpose, to figure out if a transpose is needed when calling BLAS?

The body of the else if constexpr (/* all the types are BLAS friendly */) branch would have all of this logic, which I omitted in the above pseudocode. The idea is to strip off all the layout_transpose, accessor_conjugate, and accessor_scaled first. (This would reduce transposed(scaled(alpha, transposed(A))) to (alpha, A), for example.) Once the code reaches a known layout and accessor, then it can look at the actual layouts and transpose information. For example, if A is layout_right but trans is ETrans::T, and B is layout_left, then the code can view A as layout_left and use TRANSA='N' with the BLAS.

This design thus separates extracting information from conjugated / scaled / transposed etc., from deducing whether the extracted information and the underlying matrix types are compatible with the BLAS (or another library, like cuBLAS). It also concentrates the "can I call the BLAS?" logic in one place (or two, counting if constexpr (/* It's obvious we can't call the BLAS */), though it would be best for that branch to check only obviously library-independent conditions, like "the matrix value types are not arithmetic types").

Personally, I'd prefer to resolve the effective trans and alpha (maybe some recursive logic on the layout/accessor) and then have a single call to the impl function :).

Absolutely, a good idea. That's what "[w]e can fix at least (2) by applying the recursive approach to each pair (A, A_data) and (B, B_data)" was getting at. Deducing trans and alpha could also be done by working directly with the layout mapping and accessor; that could save a bit of time creating the actual mdspan. (The layout mapping and accessor have to be considered together, because both affect trans.)

mhoemmen commented 2 years ago

@fnrizzi wrote:

wow you drafted this really fast!

Awww :-) All these discussions with y'all have been helpful! I wish I had figured this out sooner so you wouldn't have had to come up with those patterns on your own.

mhoemmen commented 2 years ago

Notes to implementers (e.g., @youyu3 ):

  1. Unlike the BLAS' TRANS argument, P1673's algorithms always work on the input matrix, not on the "original matrix." (mdspan layouts aren't always invertible, so there may not be an "original matrix.") This rule, while inconsistent with the BLAS, is internally consistent.

  2. The P1673R8/R9 specification of transposed only produces layout_transpose for custom layouts not in P0009 or P1673. For example, transposed turns a layout_left matrix into a layout_right matrix. This is a classic C BLAS implementation technique (support row major by treating it as the transpose of column major). However, this means that the above code will need changes in order to detect when the TRANS argument should be set. For example, the implementation may need to wait until the last step before calling the BLAS to find a common layout among all the arguments. "Finding a common layout" means, e.g., that in matrix-matrix multiply, if A is row major, B is column major, and C is row major, we may need to "transpose C in place" so that we can call a Fortran BLAS, and thus may need to do the same thing to A and B and/or change TRANSA or TRANSB.

mhoemmen commented 2 years ago

Note: the merged PR #238 fully implements conjugated, scaled, transposed, and conjugate_transposed as specified in P1673R9.

People interested in this issue (e.g., @youyu3, @fnrizzi, @MikolajZuzek ) should note that transposed now only returns layout_transpose for layouts unknown to both P0009 and P1673. This means that the else if constexpr (/* A's outer layout is layout_transpose */) branch in the above example won't work straightforwardly. Implementations will have freedom to decide how to toggle the layouts of A, B, and C in matrix_product, for instance. Thus, a recursive implementation would need some "preferred direction" for toggling layouts (e.g., prefer the same layout as the output matrix), to avoid infinite recursion.