google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

[Feature] Triplet Marginal Loss #1118

Status: Open. Opened by cvnad1 1 month ago.

cvnad1 commented 1 month ago

I would like to add the triplet margin loss function.

Reference: https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#triplet_margin_loss
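For context, the referenced PyTorch function computes, per triplet, `max(d(a, p) - d(a, n) + margin, 0)`, where `d(x, y) = ||x - y + eps||_p`. Its defaults are `margin=1.0`, `p=2`, `eps=1e-6`, and an optional `swap` that replaces `d(a, n)` with `min(d(a, n), d(p, n))`. The following is a minimal JAX sketch of that formula, not the implementation proposed for optax. The function name and signature are hypothetical. It returns unreduced per-example losses, as many optax losses do, and leaves any averaging to the caller (PyTorch averages by default).

```python
import jax
import jax.numpy as jnp


def triplet_margin_loss(anchors, positives, negatives,
                        margin=1.0, p=2.0, eps=1e-6, swap=False):
  """Hypothetical sketch of a triplet margin loss (mirrors PyTorch's formula).

  Args:
    anchors, positives, negatives: arrays of shape [..., dim].
    margin: required gap between positive and negative distances.
    p: order of the vector norm used as the distance.
    eps: small constant added to the difference, as in PyTorch.
    swap: if True, use min(d(a, n), d(p, n)) as the negative distance.

  Returns:
    Unreduced per-example losses of shape [...].
  """
  def distance(x, y):
    # p-norm of (x - y + eps) over the last (feature) axis.
    diff = jnp.abs(x - y + eps)
    return jnp.sum(diff ** p, axis=-1) ** (1.0 / p)

  d_ap = distance(anchors, positives)
  d_an = distance(anchors, negatives)
  if swap:
    # Distance swap: use the harder of the two negative distances.
    d_an = jnp.minimum(d_an, distance(positives, negatives))
  # Hinge: penalize only when the negative is not at least `margin` farther.
  return jnp.maximum(d_ap - d_an + margin, 0.0)


# Usage: jit-compile and average over the batch.
loss_fn = jax.jit(triplet_margin_loss, static_argnames=("swap",))
a = jnp.zeros((2, 2))
pos = jnp.zeros((2, 2))
neg = jnp.array([[3.0, 4.0], [0.3, 0.4]])
mean_loss = jnp.mean(loss_fn(a, pos, neg))
```

The first negative is 5 units away, so it already satisfies the margin and contributes zero loss. The second is only 0.5 away and contributes about 0.5.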

vroulet commented 1 month ago

Hello @cvnad1,

Feel free to proceed and add it to the `self_supervised` folder of the losses.

cvnad1 commented 4 weeks ago

@vroulet Hey, sorry for the slight delay.

I completed the code for the loss function and am currently adding tests. Once the tests are added, I will create a PR.

Thanks for pointing me to where to add it!