RosettaCommons / binder

Binder, tool for automatic generation of Python bindings
MIT License

Modify reference to primitive type #261

Status: Open. mattangus opened this issue 1 year ago.

mattangus commented 1 year ago

I am trying to wrap a library whose accessors return non-const references to private member variables, so that callers update those members by assigning through the returned references. For example:

```cpp
float & SomeClass::floatVal() { return _float_member; }
std::vector<float> & SomeClass::arrayVal() { return _array_member; }
```

The library is meant to be used as shown below. I don't control the library, so I can't change its interface, short of hand-wrapping each class, which IMO defeats the purpose of Binder.

```cpp
SomeClass cl;
cl.floatVal() = 0.1f;
```
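A minimal self-contained sketch of this accessor pattern follows. The class body, the members' initial values, and the `intended_usage` function are assumptions for illustration, since the issue only shows the two accessor definitions.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the wrapped library class. Only the two accessor
// signatures come from the issue; the members' initial values are assumed.
class SomeClass {
public:
    float &floatVal() { return _float_member; }
    std::vector<float> &arrayVal() { return _array_member; }

private:
    float _float_member = 0.0f;
    std::vector<float> _array_member = {0.0f, 0.0f, 0.0f};
};

// The intended C++ usage: assign through the returned lvalue references.
void intended_usage() {
    SomeClass cl;
    cl.floatVal() = 0.1f;      // writes _float_member
    cl.arrayVal()[0] = 42.0f;  // writes _array_member[0]
    assert(cl.floatVal() == 0.1f);
    assert(cl.arrayVal()[0] == 42.0f);
}
```

In C++ this works because `cl.floatVal()` is an lvalue of type `float`. That is exactly what is lost once the value crosses into Python.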

I have set my config to use the `reference_internal` return value policy for reference returns:

```
+default_member_lvalue_reference_return_value_policy pybind11::return_value_policy::reference_internal
+default_member_rvalue_reference_return_value_policy pybind11::return_value_policy::reference_internal
```

which generates code like this:

```cpp
cl.def("floatVal", (float & (SomeClass::*)()) &SomeClass::floatVal, "C++: SomeClass::floatVal() --> float &", pybind11::return_value_policy::reference_internal);
cl.def("arrayVal", (std::vector<float> & (SomeClass::*)()) &SomeClass::arrayVal, "C++: SomeClass::arrayVal() --> std::vector<float> &", pybind11::return_value_policy::reference_internal);
```
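The cast `(float & (SomeClass::*)())` in the generated code names the exact member-function type. This also selects the right overload when there is more than one. Calling through that member pointer still yields an lvalue reference, so on the C++ side the reference is intact. Below is a sketch with a hypothetical `SomeClass`. The `const` overload and the helper names are assumptions, added only to show the disambiguation.

```cpp
#include <cassert>

// Hypothetical stand-in for the library class. The const overload is an
// assumption added only to show why an explicit cast is useful.
class SomeClass {
public:
    float &floatVal() { return _float_member; }
    const float &floatVal() const { return _float_member; }

private:
    float _float_member = 0.0f;
};

// Same member-function type as the cast in the generated binding:
// it picks the non-const overload out of the overload set.
using FloatAccessor = float &(SomeClass::*)();
constexpr FloatAccessor accessor = static_cast<FloatAccessor>(&SomeClass::floatVal);

// Calling through the member pointer still yields an lvalue reference,
// so assignment reaches the member at the C++ level.
void assign_via_member_pointer(SomeClass &cl, float value) {
    float &ref = (cl.*accessor)();
    ref = value;
    assert(&ref == &cl.floatVal());  // same object a direct call refers to
}
```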

The problem appears when I try to update the float value from Python:

```python
cl = my_module.SomeClass()
cl.arrayVal()[:] = [1.0, 2., 3.]  # works, because slice assignment modifies the vector in place
cl.arrayVal()[0] = 42.  # this also works
cl.floatVal() = 1.23  # SyntaxError: a function call cannot be an assignment target
temp = cl.floatVal()
temp = 1.23  # only rebinds the name `temp`; the underlying C++ member is not updated
```

The error is `SyntaxError: cannot assign to function call here. Maybe you meant '==' instead of '='?`. I can't find any way to update a reference to a primitive type.
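In Python, `temp = 1.23` only rebinds the name `temp`. The `float` that `cl.floatVal()` returns is also a new, immutable Python object, because pybind11's built-in caster for arithmetic types copies the value. The return value policy therefore does not turn it into a reference. Below is the C++ analogue of the two cases, using a hypothetical `SomeClass` and helper names chosen for illustration.

```cpp
#include <cassert>

// Hypothetical stand-in for the library class.
class SomeClass {
public:
    float &floatVal() { return _float_member; }

private:
    float _float_member = 0.0f;
};

// Mirrors `temp = cl.floatVal(); temp = 1.23` in Python: temp is a copy.
float assign_to_copy(SomeClass &cl) {
    float temp = cl.floatVal();  // copies the value out of the reference
    temp = 1.23f;                // changes only the local copy
    assert(temp == 1.23f);
    return cl.floatVal();        // the member is unchanged
}

// What the library intends: keep the reference and assign through it.
float assign_through_reference(SomeClass &cl) {
    float &ref = cl.floatVal();  // binds to _float_member
    ref = 1.23f;                 // writes the member
    return cl.floatVal();
}
```

From Python, only the first behaviour is reachable through the existing `floatVal` binding. That is why some separate setter entry point is needed.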

Does Binder have any option that allows this? For example, a flag to generate a setter for a primitive-type reference?

mattangus commented 1 year ago

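I hacked something together; maybe it will be useful to others. It generates a `set_<name>` binding for functions that return a reference to a fundamental type such as `float &`: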

```cpp
// Generate a "set_<name>" binding for a zero-argument function that returns a
// non-const reference to a fundamental type, e.g. `float & SomeClass::floatVal()`.
string bind_function_setter(const string &module, FunctionDecl const *F, Context &context, CXXRecordDecl const *parent) {
    string code;
    string function_name = python_function_name(F);
    string function_qualified_name = standard_name(parent ? class_qualified_name(parent) + "::" + F->getNameAsString() : F->getQualifiedNameAsString());
    CXXMethodDecl const *m = dyn_cast<CXXMethodDecl>(F);

    clang::QualType const qt = F->getReturnType();
    // Only zero-argument functions are handled: the generated lambda calls F() with no arguments.
    if( qt->isReferenceType() and F->getNumParams() == 0 ) {
        clang::QualType const nonRefQt = qt.getNonReferenceType();
        // Skip `const T &` returns: assigning through them would not compile.
        if( nonRefQt->isFundamentalType() and !nonRefQt.isConstQualified() ) {
            string maybe_static;
            if( m and m->isStatic() ) {
                maybe_static = "_static";
                function_name = Config::get().prefix_for_static_member_functions() + function_name;
            }

            string documentation = "setter for primitive reference return value {}"_format(F->getQualifiedNameAsString());

            // Fundamental type names never start with enum/class/struct, so the
            // GCC lambda-return-type workaround from bind_function is not needed here.
            string value_arg = standard_name(nonRefQt) + " value";

            string func;
            if( m and !m->isStatic() ) {
                // forcing object type to be of parent class so member function with lifted access could be used
                string object = class_qualified_name(parent ? parent : m->getParent()) + (m->isConst() ? " const" : "") + " &o";
                func = "[]({}, {}) -> void {{ o.{}() = value; }}"_format(object, value_arg, F->getNameAsString());
            }
            else {
                // static member or free function: no object parameter, so no leading comma
                func = "[]({}) -> void {{ {}() = value; }}"_format(value_arg, function_qualified_name);
            }

            code = module + R"(.def{}("set_{}", {}, "{}");)"_format(maybe_static, function_name, func, documentation) + "\n";
        }
    }

    return code;
}
```

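For a zero-argument member function such as `SomeClass::floatVal`, the generator above should emit a binding roughly like `cl.def("set_floatVal", [](SomeClass &o, float value) -> void { o.floatVal() = value; }, "...");`. The exact qualified class name depends on Binder's naming. The lambda itself is ordinary C++, so it can be checked without pybind11. `SomeClass` below is a hypothetical stand-in.

```cpp
#include <cassert>

// Hypothetical stand-in for the library class.
class SomeClass {
public:
    float &floatVal() { return _float_member; }

private:
    float _float_member = 0.0f;
};

// The lambda body the setter generator would emit for SomeClass::floatVal
// (zero-argument, non-static, non-const member function).
auto set_floatVal = [](SomeClass &o, float value) -> void { o.floatVal() = value; };

void demo_generated_setter() {
    SomeClass cl;
    set_floatVal(cl, 1.23f);
    assert(cl.floatVal() == 1.23f);
}
```

In Python the call would then read `cl.set_floatVal(1.23)` instead of the invalid `cl.floatVal() = 1.23`.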
lyskov commented 1 year ago

@mattangus at the moment Binder does not have options to trigger 'setter' generation for cases like this. If you do not control this package's source code, a possible workaround is to use the `+add_on_binder` config option to call your own binding code and bind the desired 'setter' there (perhaps defined as a lambda function). Hope this helps,
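Below is a sketch of the C++ side of that workaround. It defines a small generic helper, `make_ref_setter`, which is hypothetical and not part of Binder or pybind11. The helper turns any accessor of the form `T & C::name()` into a setter callable, so each accessor does not need a hand-written lambda body. Inside a custom add-on binder function it could presumably be registered with something like `cl.def("set_floatVal", make_ref_setter(&SomeClass::floatVal));`. The exact function signature Binder expects for `+add_on_binder` is not shown here, so treat that part as an assumption and check Binder's documentation. The deduced return type requires C++14 or later.

```cpp
#include <cassert>

// Hypothetical stand-in for the library class.
class SomeClass {
public:
    float &floatVal() { return _float_member; }
    int &intVal() { return _int_member; }

private:
    float _float_member = 0.0f;
    int _int_member = 0;
};

// Hypothetical helper: wraps an accessor `T &C::name()` into a callable
// `(C &, T) -> void` that assigns through the returned reference.
template <typename C, typename T>
auto make_ref_setter(T &(C::*accessor)()) {
    return [accessor](C &o, T value) { (o.*accessor)() = value; };
}

// Usage: one setter per accessor, no hand-written lambda bodies.
void demo_ref_setter() {
    SomeClass cl;
    auto set_float = make_ref_setter(&SomeClass::floatVal);
    auto set_int = make_ref_setter(&SomeClass::intVal);
    set_float(cl, 1.23f);
    set_int(cl, 7);
    assert(cl.floatVal() == 1.23f);
    assert(cl.intVal() == 7);
}
```

Because the member pointer's type carries both `C` and `T`, template argument deduction recovers the class and value type from `&SomeClass::floatVal` alone.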