albarji / proxTV

Matlab and Python toolbox for fast Total Variation proximity operators

Weighted loss function #37

Closed david-castillo closed 6 years ago

david-castillo commented 7 years ago

Hi,

We are trying to solve a fused lasso problem with a slightly different loss function. In our case, the loss function is weighted as

min ||Wi(Xi-Yi)||^2 + lambda ....

To adapt proxTV to this case, do you think it would be feasible and reasonably simple for us to build a new proximity operator based on the function tautString_TV1_Weighted?

Regards

David

albarji commented 7 years ago

Hi David,

What is the equation for the whole objective function you want to optimize? If you want to use the usual fused lasso regularizers and are only changing the loss function, then you can solve that fairly easily. You can use a proximal gradient or FISTA algorithm, supplying a routine that computes the gradient of your loss together with the tools in proxTV for the Total Variation + L1 regularizer.
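
The proximal gradient iteration mentioned here can be written generically as below, with $f$ the loss, $g$ the regularizer and $t > 0$ a step size (standard notation, not taken from proxTV):

$$X^{k+1} = \operatorname{prox}_{t g}\!\big(X^k - t\, \nabla f(X^k)\big), \qquad \operatorname{prox}_{t g}(V) = \arg\min_U \; \tfrac{1}{2}\|U - V\|^2 + t\, g(U).$$

With $g = \lambda\,\mathrm{TV}$, each prox step is a TV denoising problem with penalty $\lambda t$, which is the kind of problem proxTV's solvers address; FISTA adds a momentum (extrapolation) step on top of this iteration.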

david-castillo commented 7 years ago

Hi,

First, I should say that I'm not an expert on these algorithms. The full function has the usual regularizers and looks like:

min ||Wi(Xi-Yi)||^2 + lambda Wi(Xi-Xj)

in two dimensions.
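
Written out in standard notation, the objective above might read as follows. This is a sketch under assumptions: the weights $W_i$ act elementwise inside the loss, and the penalty is read as an (anisotropic) 2D total variation over neighboring pixel pairs $(i, j) \in \mathcal{N}$ with per-pair weights $w_{ij}$, which is one interpretation of the `Wi(Xi-Xj)` term; setting every $w_{ij} = 1$ gives the unweighted 2D TV penalty.

$$\min_{X} \; \sum_i W_i^2 \, (X_i - Y_i)^2 \;+\; \lambda \sum_{(i,j) \in \mathcal{N}} w_{ij} \, |X_i - X_j|$$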

We are currently using a graph fused lasso package from https://github.com/tansey/gfl, which works well and supports different loss functions. The problem is that the algorithm it uses, ADMM, seems to converge slowly, and we are struggling to make it faster because we don't see how to parallelize it. Can FISTA be parallelized? Any suggestions would be appreciated.

Thanks

David

albarji commented 7 years ago

ADMM is indeed quite slow unless you find good stepsize parameters, which is no easy task in general. But this problem should be solvable through a combination of FISTA and the 2D Total Variation proximity solver provided in proxTV.

You see, FISTA is a general algorithm for minimizing the sum of a differentiable function (your loss) and a non-differentiable function (the regularizer). For the differentiable part, you need to provide a way to compute its gradient, which at first glance should be simple for your problem. For the non-differentiable part, you need to provide the prox operator, but you already have that coded in proxTV!
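
For the weighted loss, the gradient and the Lipschitz constant of that gradient follow directly, assuming $f(X) = \|W \odot (X - Y)\|^2$ with elementwise weights:

$$f(X) = \sum_i W_i^2 (X_i - Y_i)^2, \qquad \big(\nabla f(X)\big)_i = 2\, W_i^2 \,(X_i - Y_i), \qquad L = 2 \max_i W_i^2.$$

Because the gradient is computed elementwise, it is trivially parallel, and a fixed step size $t = 1/L$ is a standard safe choice for proximal gradient or FISTA.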

About parallelization: FISTA essentially alternates between calls to the gradient function and to the prox operator you provide. The 2D Total Variation solver in proxTV already supports parallelization, so half of your problem is already solved.

FISTA is not implemented in this library, though I wrote some scripts using it for the related paper. I'm attaching a general implementation of FISTA in MATLAB. As you will see, you need to provide gradient and prox functions appropriate for your problem, as discussed above.

FISTAB.txt
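
A minimal Python sketch of the FISTA loop described above (it is not the attached MATLAB script, whose contents are not reproduced here). It assumes NumPy and the weighted loss $\|W \odot (X - Y)\|^2$ with elementwise weights, and it uses soft-thresholding (the prox of an L1 penalty) as a self-contained stand-in for the 2D TV prox. The names `fista`, `weighted_ls_grad` and `soft_threshold` are illustrative, not part of proxTV.

```python
import numpy as np


def fista(grad, prox, x0, step, n_iter=2000):
    """FISTA for min_x f(x) + g(x) with a fixed step size.

    grad(x)    -> gradient of the smooth part f at x
    prox(v, t) -> argmin_u 0.5 * ||u - v||^2 + t * g(u)
    step       -> fixed step, e.g. 1 / L with L the Lipschitz constant of grad
    """
    x = np.array(x0, dtype=float)  # current iterate (a copy of x0)
    z = x.copy()                   # extrapolated point where the gradient is taken
    t_k = 1.0                      # momentum sequence, starts at 1
    for _ in range(n_iter):
        # Proximal gradient step taken from the extrapolated point
        x_new = prox(z - step * grad(z), step)
        # Momentum update and extrapolation
        t_next = (1.0 + np.sqrt(1.0 + 4.0 * t_k ** 2)) / 2.0
        z = x_new + ((t_k - 1.0) / t_next) * (x_new - x)
        x, t_k = x_new, t_next
    return x


def weighted_ls_grad(W, Y):
    """Gradient of f(X) = ||W * (X - Y)||^2 with elementwise weights W."""
    return lambda X: 2.0 * W ** 2 * (X - Y)


def soft_threshold(v, t, lam):
    """Prox of t * lam * ||.||_1, used here as a stand-in for a TV prox."""
    return np.sign(v) * np.maximum(np.abs(v) - t * lam, 0.0)


# Small demo on random 2D data
rng = np.random.default_rng(0)
Y = rng.normal(size=(4, 5))
W = rng.uniform(0.5, 2.0, size=(4, 5))
lam = 0.3
L = 2.0 * np.max(W ** 2)  # Lipschitz constant of the gradient
X = fista(weighted_ls_grad(W, Y),
          lambda v, t: soft_threshold(v, t, lam),
          np.zeros_like(Y), step=1.0 / L)
```

To use a 2D TV prox instead of the L1 stand-in, the `prox` argument would become something like `lambda v, t: ptv.tv1_2d(v, lam * t)`, assuming proxTV's Python package is imported as `import prox_tv as ptv` and that `tv1_2d(X, w)` solves $\min_U \tfrac{1}{2}\|U - X\|^2 + w\,\mathrm{TV}(U)$. The prox call is then where proxTV's parallel 2D solver does the heavy lifting.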

Hope this helps!

david-castillo commented 7 years ago

It definitely helps a lot. I'll try to implement it and compare it with what we have. Thanks!