awslabs / gluonts

Probabilistic time series modeling in Python
https://ts.gluon.ai
Apache License 2.0

custom gluonts framework with triplet loss #722

Open sy2657 opened 4 years ago

sy2657 commented 4 years ago

Hi, I'm interested in implementing a triplet loss. I see that `F` is a parameter of `hybrid_forward` in the definitions of the custom TrainNetwork and PredNetwork. Where is `F` defined, and how can I compute a loss over three inputs?

Here are more details about the triplet loss (https://arxiv.org/pdf/1901.10738.pdf):
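For reference, here is my reading of the loss in the linked paper (a triplet loss with negative sampling). $f$ is the encoder, $x^{\mathrm{ref}}$ a reference window, $x^{\mathrm{pos}}$ a positive window, $x^{\mathrm{neg}}_1, \dots, x^{\mathrm{neg}}_K$ negative windows, and $\sigma$ the sigmoid. With $K = 1$ it takes exactly three inputs. Please check it against the paper itself:

$$
\mathcal{L} = -\log \sigma\big(f(x^{\mathrm{ref}})^\top f(x^{\mathrm{pos}})\big) - \sum_{k=1}^{K} \log \sigma\big(-f(x^{\mathrm{ref}})^\top f(x^{\mathrm{neg}}_k)\big)
$$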

StatMixedML commented 4 years ago

Very relevant topic. I hope it gets picked up!

AaronSpieler commented 4 years ago

Hello,

Have you come up with a solution?

`F` is just the backend namespace, i.e. `mx.nd` (NDArray) when the block is not hybridized or `mx.sym` (Symbol) when it is. It's used to call the appropriate tensor functions.
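A minimal sketch of how `F` lets one loss function serve both modes, and of combining three inputs into one loss. The helper names `triplet_loss` and `softplus` are hypothetical. The embeddings are assumed to come from an encoder that is not shown, and only one negative is used (K = 1 in the paper's formula). The sketch sticks to `F.sum`, `F.exp`, `F.log1p`, `F.abs` and `F.maximum`, which as far as I know exist in both `mx.nd` and `mx.sym`. NumPy is passed as `F` only so that the sketch runs without MXNet. I have not tested it inside a hybridized GluonTS network.

```python
import numpy as np


def softplus(F, z):
    """Numerically stable log(1 + exp(z)), built only from ops shared by mx.nd, mx.sym and NumPy."""
    return F.maximum(z, 0) + F.log1p(F.exp(-F.abs(z)))


def triplet_loss(F, ref, pos, neg):
    """Hypothetical helper: per-example triplet loss with a single negative (K = 1).

    ref, pos, neg: embeddings of shape (batch, dim), assumed to be already produced by an encoder.
    Returns an array of shape (batch,).
    """
    sim_pos = F.sum(ref * pos, axis=-1)  # f(x_ref)^T f(x_pos)
    sim_neg = F.sum(ref * neg, axis=-1)  # f(x_ref)^T f(x_neg)
    # -log(sigmoid(s)) == softplus(-s)  and  -log(sigmoid(-s)) == softplus(s)
    return softplus(F, -sim_pos) + softplus(F, sim_neg)


# In a Gluon HybridBlock you would call triplet_loss(F, ref, pos, neg) from
# hybrid_forward(self, F, ...). Here NumPy stands in for F so the sketch runs without MXNet.
ref = np.array([[1.0, 0.0]])
pos = np.array([[2.0, 0.0]])
neg = np.array([[-2.0, 0.0]])
print(triplet_loss(np, ref, pos, neg))
```

Writing `-log(sigmoid(s))` as a softplus avoids taking the log of a sigmoid that has underflowed to zero for large similarities. The loss is small when the reference embedding points toward the positive and away from the negative.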

Fundamental difficulties that I see: