deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

How to define BPRLoss and MarginLoss in DJL? Are there any samples? #1223

Open dxjjhm opened 2 years ago

dxjjhm commented 2 years ago

How can I define BPRLoss and MarginLoss in DJL? Are there any samples?

My data is pairwise, as is common in recommender systems: each sample is a (user, pos_item, neg_item) triple.

I have already looked at the source of Loss in DJL, but I still don't know how to do it.

In PyTorch, a loss can be written by extending nn.Module, but in DJL, Loss extends Evaluator rather than AbstractBlock. Could someone explain how to do this?

zachgk commented 2 years ago

You want to extend the Loss class. It should be fairly easy: you only need to implement the evaluate method. For a simple example, look at L2Loss.
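As a minimal sketch of that advice, assuming the model's forward pass returns an NDList whose first array holds the (user, pos_item) scores and whose second array holds the (user, neg_item) scores, a BPR loss only has to implement evaluate(...). BPRLoss is a hypothetical class name, not something shipped with DJL. The labels are ignored here because BPR only compares the two scores. Running it requires DJL plus an engine (for example PyTorch) on the classpath.

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.nn.Activation;
import ai.djl.training.loss.Loss;

/**
 * Hypothetical BPR (Bayesian Personalized Ranking) loss, written as a DJL Loss subclass.
 *
 * Assumption: predictions.get(0) = scores for (user, pos_item),
 *             predictions.get(1) = scores for (user, neg_item), same shape.
 * The labels NDList is not used.
 */
public class BPRLoss extends Loss {

    public BPRLoss() {
        this("BPRLoss");
    }

    public BPRLoss(String name) {
        super(name);
    }

    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        NDArray posScores = predictions.get(0);
        NDArray negScores = predictions.get(1);
        // x = s(u, i+) - s(u, i-); BPR minimizes -ln(sigmoid(x)), averaged over the batch
        NDArray diff = posScores.sub(negScores);
        return Activation.sigmoid(diff).log().neg().mean();
    }
}
```

For example, positive scores [2, 0] against negative scores [0, 0] give roughly (0.1269 + 0.6931) / 2 ≈ 0.41. Because log(sigmoid(x)) can underflow when the negative item scores far above the positive one, a numerically safer formulation of -ln σ(x) may be worth using in practice.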

In evaluate(NDList labels, NDList predictions), the pairwise data is passed in as NDLists. The predictions NDList should hold the output produced by your model, and the labels NDList holds the true answer. If either one consists of multiple pieces of data, you can store them in the same NDArray or use the NDList like a tuple, with one NDArray per piece.
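Both packing options can be illustrated with a hypothetical pairwise margin (hinge) loss, mean(max(0, margin - pos + neg)). The sketch below assumes the model emits either a two-element NDList (tuple style: positive scores, then negative scores) or a single NDArray of shape (batch, 2) whose columns are [posScore, negScore]. MarginLoss here is an illustrative class name, not a DJL built-in, and the column layout is an assumption you would match to your own model.

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.loss.Loss;

/**
 * Hypothetical pairwise margin loss: mean(max(0, margin - pos + neg)).
 * Accepts the scores either as an NDList "tuple" (index 0 = pos, index 1 = neg)
 * or as one stacked NDArray of shape (batch, 2) with columns [pos, neg].
 */
public class MarginLoss extends Loss {

    private final float margin;

    public MarginLoss(float margin) {
        super("MarginLoss");
        this.margin = margin;
    }

    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        NDArray pos;
        NDArray neg;
        if (predictions.size() == 1) {
            // Single stacked array: column 0 = positive scores, column 1 = negative scores
            NDArray stacked = predictions.singletonOrThrow();
            pos = stacked.get(":, 0");
            neg = stacked.get(":, 1");
        } else {
            // Tuple style: one NDArray per piece of data
            pos = predictions.get(0);
            neg = predictions.get(1);
        }
        // Penalize pairs where the positive item does not beat the negative one by at least `margin`
        return neg.sub(pos).add(margin).maximum(0f).mean();
    }
}
```

With margin 1, positive scores [2, 0] and negative scores [0, 0] give per-pair losses [0, 1] and a mean of 0.5, whichever packing is used. Stacking into one NDArray keeps the model output compact, while the tuple form avoids reshaping when the two scores come from separate branches of the model.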