ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Additional losses #336

Closed. Jyun1998 closed this 4 months ago

Jyun1998 commented 5 months ago

Proposed changes

Added commonly used losses with tests

Checklist

Put an x in the boxes that apply.

Jyun1998 commented 5 months ago

Hi @awni, I rebased and kept the cosine loss on this PR.

Also commented on #324 based on my implementation.

Thanks :)

awni commented 5 months ago

Hey @Jyun1998, sorry, but somehow my main comment did not get included. I must have forgotten to hit the save button.

Basically I'm wondering what reference you used for this loss, as it looks quite different from the similarly named cosine similarity loss in PyTorch?

For example, there is a target and a margin, neither of which I expected. Unless there is a good reason for the difference, I would suggest we follow the PyTorch implementation as a reference.

awni commented 4 months ago

@Jyun1998 are you still planning to follow up on this?

Jyun1998 commented 4 months ago

@Jyun1998 are you still planning to follow up on this?

https://github.com/keras-team/keras/blob/v2.14.0/keras/losses.py#L1162-L1236

Hi awni, following the common TensorFlow and PyTorch implementations, the function L2-normalizes each embedding and returns the negative of the dot product of the two embeddings.

My code does the same, and a margin-based loss is applied only if necessary :)
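Roughly, the computation I have in mind looks like this (just an illustrative sketch, not the exact code in this PR; the function name, the eps clamp, and the margin handling are placeholders):

```python
import mlx.core as mx

def cosine_similarity_loss(x1, x2, eps=1e-8):
    # Sketch only: L2-normalize each embedding along the last axis ...
    x1 = x1 / mx.maximum(mx.sqrt(mx.sum(x1 * x1, axis=-1, keepdims=True)), eps)
    x2 = x2 / mx.maximum(mx.sqrt(mx.sum(x2 * x2, axis=-1, keepdims=True)), eps)
    # ... and return the negative dot product of the two.
    # A hinge-style margin term could optionally be applied on top of this.
    return -mx.sum(x1 * x2, axis=-1)
```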

awni commented 4 months ago

@Jyun1998 got it. We should keep it simple until we see that we need more features. Could you follow the PyTorch cosine similarity loss? I think that one covers the most common case (margin / targets are niche and possibly never needed, so we don't want to add them to the API until we are sure they are necessary).

Jyun1998 commented 4 months ago

@Jyun1998 got it. We should keep it simple until we see that we need more features. Could you follow the PyTorch cosine similarity loss? I think that one covers the most common case (margin / targets are niche and possibly never needed, so we don't want to add them to the API until we are sure they are necessary).

I agree. Even though PyTorch's F.cosine_similarity also has a margin, it does not use it by default.

Only a slight change is needed, so I will test the function and commit ASAP. Thanks for the review.
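For concreteness, the simplified PyTorch-style version would look roughly like this (a sketch only, assuming we just need the cosine similarity along the last axis; the reduction handling and naming are illustrative, not final):

```python
import mlx.core as mx

def cosine_similarity_loss(x1, x2, eps=1e-8, reduction="none"):
    # Plain cosine similarity in the style of PyTorch's F.cosine_similarity:
    # no margin and no targets.
    x1_norm = mx.sqrt(mx.sum(x1 * x1, axis=-1))
    x2_norm = mx.sqrt(mx.sum(x2 * x2, axis=-1))
    cos_sim = mx.sum(x1 * x2, axis=-1) / mx.maximum(x1_norm * x2_norm, eps)
    # Illustrative reduction handling only.
    if reduction == "mean":
        return mx.mean(cos_sim)
    elif reduction == "sum":
        return mx.sum(cos_sim)
    return cos_sim
```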

awni commented 4 months ago

Even though PyTorch's F.cosine_similarity also has a margin

I don't see the margin in the docs? Is it in the source code?

[Screenshot: 2024-01-07 at 6:30:32 AM]

Jyun1998 commented 4 months ago

Even though PyTorch's F.cosine_similarity also has a margin

I don't see the margin in the docs? Is it in the source code?

[Screenshot: 2024-01-07 at 6:30:32 AM]

https://pytorch.org/docs/stable/generated/torch.nn.CosineEmbeddingLoss.html#torch.nn.CosineEmbeddingLoss
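(To be clear, the margin and target come from CosineEmbeddingLoss rather than from F.cosine_similarity itself; per the linked docs it computes 1 - cos(x1, x2) when the target is 1 and max(0, cos(x1, x2) - margin) when the target is -1.)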

awni commented 4 months ago

I see, thanks. Yes let's go with the plain cosine similarity for now. Thank you!

awni commented 4 months ago

Also could you rebase and resolve conflicts?

Jyun1998 commented 4 months ago

Also could you rebase and resolve conflicts?

Am I correct that the losses test code is gone?

Never mind, found the new losses test file.

Jyun1998 commented 4 months ago

Also could you rebase and resolve conflicts?

Could you check? Thanks :)