Open MarioAuditore opened 1 year ago
When you use

```python
@implicit_diff.custom_root(optimality_fun)
def fun(...):
```

`fun` and `optimality_fun` should have the same number of arguments, which is not the case here. Your function `euclidean_weighted_mean` includes non-differentiable arguments like `n_iter` and `plot_loss_flag`. You need to remove them, for example by closing over them in a factory:
```python
def make_euclidean_weighted_mean(lr=0.1, n_iter=50, plot_loss_flag=False):
    def euclidean_weighted_mean(x, X, weights):
        [...]
```
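To make the suggestion concrete, here is a minimal sketch of that factory pattern in plain NumPy. The non-differentiable hyperparameters (`lr`, `n_iter`) live in the closure, so the inner solver's signature contains only the arguments implicit differentiation needs to see. The gradient-descent loop, the quadratic objective, and all values below are assumptions for illustration, not the original code:

```python
import numpy as np

def make_euclidean_weighted_mean(lr=0.1, n_iter=200):
    """Factory: lr and n_iter are captured in the closure, so the
    inner solver only takes (x, X, weights)."""
    def euclidean_weighted_mean(x, X, weights):
        # Gradient descent on f(x) = sum_i w_i * ||x - X_i||^2
        for _ in range(n_iter):
            grad = 2.0 * np.sum(weights[:, None] * (x - X), axis=0)
            x = x - lr * grad
        return x
    return euclidean_weighted_mean

# The minimiser of sum_i w_i ||x - X_i||^2 is the weighted average,
# so we can sanity-check the solver against the closed form.
X = np.array([[0.0, 0.0], [2.0, 2.0]])
w = np.array([0.75, 0.25])
solver = make_euclidean_weighted_mean(lr=0.1, n_iter=200)
x_star = solver(np.zeros(2), X, w)
# x_star ≈ (0.75*[0,0] + 0.25*[2,2]) / 1.0 = [0.5, 0.5]
```

With this shape, the `@implicit_diff.custom_root(optimality_fun)` decorator can go on the inner function (or on the factory's return value), and `optimality_fun` can share the same `(x, X, weights)` signature.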
Hello! I tried to implement the example of implicit differentiation as shown here, but with my own functions. The task is to find the mean of a set of vectors `X` via gradient descent.

Algorithm for finding the mean:
You can launch it like this:
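The original snippet is not reproduced above; as a stand-in, a gradient-descent mean and its launch might look like this (pure NumPy; the function name, learning rate, and iteration count are assumptions):

```python
import numpy as np

def euclidean_mean(X, lr=0.1, n_iter=200):
    # Gradient descent on f(x) = sum_i ||x - X_i||^2,
    # whose minimiser is the ordinary Euclidean mean.
    x = np.zeros(X.shape[1])
    for _ in range(n_iter):
        grad = 2.0 * np.sum(x - X, axis=0)
        x = x - lr * grad
    return x

X = np.array([[0.0, 0.0], [1.0, 2.0], [2.0, 4.0]])
mean = euclidean_mean(X)   # converges to X.mean(axis=0)
```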
As you can see, I am computing a weighted version of the mean, and that's where I use jaxopt. Let me define the global objective (just as an example): I want the weights to take the values that minimise the distance between the resulting mean and a desired point. In my case, I want the weights to influence the algorithm so that the resulting mean ends up as close to `X[0]` as possible:

The problem emerges when I call

Meanwhile, the official example with Ridge regression works perfectly. Any suggestions?
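For intuition on what `custom_root` computes once the signatures line up: for this particular problem the solution and its implicit gradient have closed forms, so the implicit-function-theorem gradient of the outer objective can be checked against finite differences. All data, weights, and helper names below are made-up for illustration:

```python
import numpy as np

X = np.array([[0.0, 0.0], [2.0, 2.0], [1.0, 3.0]])
target = X[0]

def solve(w):
    # Root of the optimality condition F(x, w) = sum_i w_i (x - X_i) = 0
    return (w @ X) / w.sum()

def outer(w):
    # Global objective: squared distance of the solution to X[0]
    return np.sum((solve(w) - target) ** 2)

def outer_grad(w):
    # Implicit function theorem: dF/dx = (sum w) I, dF/dw_j = x* - X_j,
    # hence dx*/dw_j = (X_j - x*) / sum(w).
    x_star = solve(w)
    dxdw = (X - x_star) / w.sum()          # shape (n, d): row j is dx*/dw_j
    return dxdw @ (2.0 * (x_star - target))

w = np.array([0.5, 0.3, 0.2])
# Finite-difference check of the implicit gradient
eps = 1e-6
fd = np.array([
    (outer(w + eps * np.eye(3)[j]) - outer(w - eps * np.eye(3)[j])) / (2 * eps)
    for j in range(3)
])
print(np.allclose(outer_grad(w), fd, atol=1e-5))  # the two gradients agree
```

This is the gradient jaxopt's `custom_root` would produce automatically, which is why it insists that the solver and `optimality_fun` expose the same differentiable arguments.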