grimmmyshini opened this issue 2 years ago
@parth-07, I tried the fix I had in mind, but it causes 3 tests to fail (mainly the user-derived one and the functor one). I tried fixing them, but I have no idea how they are supposed to behave. Could you please look into this? I can open a WIP PR with the changes I made if you would like.
Hi Garima,
I don't think there is a trivial fix for the problem described in this issue. The possible fix you suggested would not work in all cases. We cannot rely on the compiler's type checking for derivative parameters, at least not without making major design changes. The core of the issue is the feature that lets users specify which derivatives should be computed.
For example, consider a simple function fn:
double fn(float f, double d, long double ld);
Now suppose the user makes the following differentiation request:
auto fn_grad_ld = clad::gradient(fn, "ld");
long double d_ld = 0;
fn_grad_ld.execute(3, 5, 7, &d_ld);
fn_grad_ld.execute calls the overloaded gradient, which in turn calls the actual gradient.
The whole process can be described like this:
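// overloaded gradient: every derivative parameter has type clad::array_ref<void>
void fn_grad_2(float f, double d, long double ld, clad::array_ref<void> _d_ld, clad::array_ref<void> _d_temp0, clad::array_ref<void> _d_temp1) {
  // forwards to the actual gradient, which only has the requested derivative parameter
  return fn_grad_2(f, d, ld, _d_ld);
}
// this is called by `fn_grad_ld.execute(3, 5, 7, &d_ld)`
fn_grad_2(3, 5, 7, &d_ld, nullptr, nullptr);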
Note that all derivative parameters of the overloaded gradient must have the same type (clad::array_ref<void> here) for the overloaded gradient technique to work.
If the overloaded gradient function instead kept the original type of each derivative parameter, then an argument of type clad::array_ref<float> (the derivative type for f, the first parameter) would be expected where &d_ld is being passed. &d_ld, being of type long double*, cannot be converted to clad::array_ref<float>.
There are two main ways to fix this issue:
1) Remove the functionality of specifying which derivatives should be computed, or
2) Instead of relying on Clang for type strictness, modify Clad so that it checks each CladFunction::execute call to make sure the arguments are of the correct types, and issues errors if they are not.
Couldn't we have this check as a Clad warning rather than using metaprogramming?
Yes, we can. But if we are checking this, then we should issue an error rather than a warning for incompatible types (for example, float * used where double * is expected).
Yeah, that works. Can you point @grimmmyshini to the area of the code where she can do this?
We would need to find a way, on the Clad side of things, to associate every CladFunction object with the derived function it stores. Then we would need to traverse the AST for CladFunction::execute calls. However, I just realized that this method would only work for simple cases. It would not work when we cannot determine which derived function a particular CladFunction object stores. For example, consider this code:
// here fn has parameters named i and j
auto fn_grad_i = clad::gradient(fn, "i");
auto fn_grad_j = clad::gradient(fn, "j");
decltype(fn_grad_i) selected_grad;
if (someCondition)  // only known at run time
  selected_grad = fn_grad_i;
else
  selected_grad = fn_grad_j;
We cannot tell at compile time which derived function is stored in the selected_grad object.
Minimum reproducible example:
Possible Fix: In function traits, replace the following
OutputParamType_t<Args, void>
with
array_ref<typename std::remove_pointer<Args>::type>