vgvassilev / clad

clad -- automatic differentiation for C/C++
GNU Lesser General Public License v3.0

Trying to differentiate an operator in reverse mode fails. #917

Open MihailMihov opened 4 months ago

MihailMihov commented 4 months ago

Test:

#include "clad/Differentiator/Differentiator.h"

class SimpleFunctions1 {
public:
  SimpleFunctions1() noexcept : x(0), y(0) {}
  SimpleFunctions1(double p_x, double p_y) noexcept : x(p_x), y(p_y) {}
  double x;
  double y;
  double mem_fn_1(double i, double j) { return (x + y) * i + i * j * j; }
  SimpleFunctions1 operator+(const SimpleFunctions1& other) const {
    return SimpleFunctions1(x + other.x, y + other.y);
  }
};

double fn_s1_operator(double i, double j) {
  SimpleFunctions1 obj1(2, 3);
  SimpleFunctions1 obj2(3, 5);
  return (obj1 + obj2).mem_fn_1(i, j);
}

int main() {
  auto d_fn_s1_operator = clad::gradient(&fn_s1_operator);
}

Clang invocation:

/home/mihail/dev/llvm-18.1.5/build/bin/clang-18 -cc1 -triple x86_64-unknown-linux-gnu -emit-obj -mrelax-all -dumpdir NonDifferentiable.out- -disable-free -clear-ast-before-backend -main-file-name NonDifferentiable.C -mrelocation-model pic -pic-level 2 -pic-is-pie -mframe-pointer=all -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -fdebug-compilation-dir=/home/mihail/dev/clad/build -fcoverage-compilation-dir=/home/mihail/dev/clad/build -resource-dir /home/mihail/dev/llvm-18.1.5/build/lib/clang/18 -I ../include -internal-isystem /usr/lib/gcc/x86_64-pc-linux-gnu/13/include/g++-v13 -internal-isystem /usr/lib/gcc/x86_64-pc-linux-gnu/13/include/g++-v13/x86_64-pc-linux-gnu -internal-isystem /usr/lib/gcc/x86_64-pc-linux-gnu/13/include/g++-v13/backward -internal-isystem /home/mihail/dev/llvm-18.1.5/build/lib/clang/18/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-pc-linux-gnu/13/../../../../x86_64-pc-linux-gnu/include -internal-externc-isystem /include -internal-externc-isystem /usr/include -fdeprecated-macro -ferror-limit 19 -fgnuc-version=4.2.1 -fskip-odr-check-in-gmf -fcxx-exceptions -fexceptions -fcolor-diagnostics -add-plugin clad -load /home/mihail/dev/clad/build/./lib/clad.so -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/NonDifferentiable-0608af.o -x c++ NonDifferentiable.C

Error:

NonDifferentiable.C:18:16: error: too few arguments to function call, expected 4, have 3
   15 | double fn_s1_operator(double i, double j) {
      |        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   16 |   SimpleFunctions1 obj1(2, 3);
      |   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   17 |   SimpleFunctions1 obj2(3, 5);
      |   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   18 |   return (obj1 + obj2).mem_fn_1(i, j);
      |   ~~~~~~~~~~~~~^

The operator_plus_pullback that Clad is generating is:

void operator_plus_pullback(const SimpleFunctions1 &other, SimpleFunctions1 _d_y, SimpleFunctions1 *_d_this, SimpleFunctions1 *_d_other) const;

I don't understand what the error is here, so I'm not sure how to go about fixing it. I discovered this while creating tests for [[clad::non_differentiable]] in reverse mode, and the error is now blocking that PR.
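For reference, here is a hand-written sketch of what a pullback with the four-parameter shape shown above could compute for a plain addition operator. It is not Clad's generated code: `Pair`, `Adjoints`, and `run_pullback` are hypothetical names used only for illustration, and the body assumes the usual reverse-mode rule that the seed of a sum flows unchanged to both operands.

```cpp
#include <cassert>

// Hypothetical stand-in for the class in the test above; not Clad output.
struct Pair {
  double x;
  double y;

  Pair operator+(const Pair& other) const {
    return Pair{x + other.x, y + other.y};
  }

  // Same parameter shape as the declaration quoted above:
  // (other operand, seed for the result, adjoint of *this, adjoint of other).
  void operator_plus_pullback(const Pair& other, Pair d_y,
                              Pair* d_this, Pair* d_other) const {
    (void)other;  // addition is linear, so the operand values are not needed
    d_this->x += d_y.x;
    d_this->y += d_y.y;
    d_other->x += d_y.x;
    d_other->y += d_y.y;
  }
};

// Holds the adjoints accumulated for the two operands.
struct Adjoints {
  Pair d_lhs;
  Pair d_rhs;
};

// Calls the pullback with all four explicit arguments; the object the
// method is called on (lhs) is supplied separately, in front of the dot.
Adjoints run_pullback(const Pair& lhs, const Pair& rhs, Pair seed) {
  Adjoints out{};  // value-initialized, so every adjoint starts at zero
  lhs.operator_plus_pullback(rhs, seed, &out.d_lhs, &out.d_rhs);
  return out;
}
```

A call such as `run_pullback(Pair{2.0, 3.0}, Pair{3.0, 5.0}, Pair{1.0, 0.5})` supplies four explicit arguments plus the object, which is the shape the diagnostic above says is expected.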

vgvassilev commented 4 months ago

Don't we need to just pass the 4 arguments instead of 3?

gojakuch commented 3 months ago

From my limited experience working with operator differentiation in lambda functions, this error typically means that the argument list constructed for the pullback is ill-formed. Pullbacks normally accept a pointer to the derivative object of the class. I assume this may not be the case for non-differentiable types, which may result in this mismatch: if the VisitCallExpr method doesn't realise that we don't have to pass any pointers to the derivative objects to the pullback, it will add this argument to the list anyway. From what I understand, Clang also does not count the first, implicit argument, which is the object the method is called on, so in the AST it probably expects 4 args and actually receives 5 (if I remember correctly). This is only my assumption, though, so it may not be exactly the case, and maybe you've already checked all of this and found the issue.
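The point about the implicit object can be checked in plain C++, independently of Clad: the parameter list of a member function does not include the object it is called on. The sketch below uses hypothetical names (`Point`, `pullback_like`, `explicit_param_count`) and only demonstrates the language rule. It makes no claim about how Clad builds its argument list internally.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical class with a method shaped like the pullback declaration above.
struct Point {
  double x;
  double y;

  void pullback_like(const Point& other, Point d_y,
                     Point* d_this, Point* d_other) const {
    (void)other;
    d_this->x += d_y.x;
    d_this->y += d_y.y;
    d_other->x += d_y.x;
    d_other->y += d_y.y;
  }
};

// Counts the explicit parameters of a const member function.
// The implicit object (`this`) is not part of Args.
template <class C, class R, class... Args>
constexpr std::size_t explicit_param_count(R (C::*)(Args...) const) {
  return sizeof...(Args);
}

// Four explicit parameters, even though a call involves five values
// once the object itself is included.
static_assert(explicit_param_count(&Point::pullback_like) == 4,
              "the implicit object is not counted as a parameter");

// Runtime version of the same check, for use outside constant expressions.
void check_counts() {
  assert(explicit_param_count(&Point::pullback_like) == 4);
}
```

So a count of 4 in a diagnostic would cover only the explicit arguments, with the object expression tracked separately.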