Cytnx-dev / Cytnx

Project Cytnx, A Cross-section of Python & C++, Tensor network library
Apache License 2.0

linalg internal uses NULL pointers for undefined functions #487

Open manuschneider opened 1 month ago

manuschneider commented 1 month ago

Here is an example of the current behavior of many backend functions for different dtypes:

In src/backend/linalg_internal_interface.hpp, a vector of function pointers is declared:

```cpp
std::vector<Qrfunc_oii> QR_ii;
```

Then, in src/backend/linalg_internal_interface.cpp, this vector is filled with the implementations for the different dtypes:

```cpp
QR_ii = vector<Qrfunc_oii>(5);

QR_ii[Type.ComplexDouble] = QR_internal_cd;
QR_ii[Type.ComplexFloat] = QR_internal_cf;
QR_ii[Type.Double] = QR_internal_d;
QR_ii[Type.Float] = QR_internal_f;
```

I see two problems with this. First, the vector in this case only contains 5 entries, so indexing it with a dtype whose value is 5 or larger (e.g. QR_ii[Type.Int]) is an out-of-bounds access. Since operator[] does no bounds checking, this is undefined behavior rather than an error that tells the user anything useful. Second, QR_ii contains 5 elements, but only 4 of them are assigned. The remaining entry is value-initialized to a NULL pointer, and calling it leads to a segfault or kernel crash. This makes debugging very hard.

Here is a suggestion to make things more robust: add fallback implementations of all internal functions. In src/backend/linalg_internal_interface.cpp:

```cpp
QR_ii = vector<Qrfunc_oii>(N_Type, QR_internal_fallback);

QR_ii[Type.ComplexDouble] = QR_internal_cd;
QR_ii[Type.ComplexFloat] = QR_internal_cf;
QR_ii[Type.Double] = QR_internal_d;
QR_ii[Type.Float] = QR_internal_f;
```

In src/backend/linalg_internal_cpu/QR_internal.hpp:

```cpp
void QR_internal_fallback(const boost::intrusive_ptr<Storage_base> &in,
                          boost::intrusive_ptr<Storage_base> &Q,
                          boost::intrusive_ptr<Storage_base> &R,
                          boost::intrusive_ptr<Storage_base> &D,
                          boost::intrusive_ptr<Storage_base> &tau, const cytnx_int64 &M,
                          const cytnx_int64 &N, const bool &is_d);
```

In src/backend/linalg_internal_cpu/QR_internal.cpp:

```cpp
void QR_internal_fallback(const boost::intrusive_ptr<Storage_base> &in,
                          boost::intrusive_ptr<Storage_base> &Q,
                          boost::intrusive_ptr<Storage_base> &R,
                          boost::intrusive_ptr<Storage_base> &D,
                          boost::intrusive_ptr<Storage_base> &tau, const cytnx_int64 &M,
                          const cytnx_int64 &N, const bool &is_d) {
  // Reached only for dtypes without a dedicated implementation.
  cytnx_error_msg(true, "[ERROR][linalg_internal] QR_internal not implemented for this data type.%s", "\n");
}
```

However, this would have to be done for many internal functions in the backend.

IvanaGyro commented 1 month ago

Many of the xx_cd, xx_cf, ... functions do similar things. Another way to ease the pain is to let the callers take responsibility for type checking, and then to templatize the xx_cd, xx_cf, ... functions or use function overloading: xx(int value) {}, xx(double value) {}. That way, some errors will be caught at compile time.

manuschneider commented 1 month ago

Indeed, it would be possible not to provide these function pointer vectors and to let the calling function handle things. But then every calling function would need a switch-case statement to call the correct implementation. I think the current implementation is more convenient: one can just call QR_ii[dtype] for a generic dtype. But it is not very robust, because of the NULL pointers and the vectors that are shorter than N_Type.

Function overloading would be a clean way to implement this, but unfortunately it does not work here. The function arguments are tensors whose dtype is only known at run time, so the arguments always have the same C++ types, and the implementation has to be chosen based on their dtype.

The only alternative I can think of is to provide a single function that checks the dtype of the input tensors and then calls the correct implementation (possibly a bit slower).

IvanaGyro commented 1 month ago

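Instead of writing a fallback function by hand for each function, we can use a template. This solution needs C++20, because it passes a string literal as a template argument (class types as non-type template parameters).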

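```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

// Wraps a string literal so it can be used as a non-type template argument (C++20).
template<std::size_t N>
struct StringLiteral {
    constexpr StringLiteral(const char (&str)[N]) {
        std::copy_n(str, N, value);
    }

    char value[N];
};

// Stand-in for an unimplemented function; Name identifies it in the message.
template<StringLiteral Name, typename Return, typename... Args>
Return NullFunction(Args... args) {
    std::cout << "Calling a function pointing to nullptr: " << Name.value << std::endl;
    return Return();
}
```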

Here is a demo.