interesaaat / TorchSharp

.NET bindings for the Pytorch engine
MIT License

Loss function class #17

Closed: artidoro closed this issue 5 years ago

artidoro commented 5 years ago

In PyTorch, loss functions in torch.nn work as follows:

import torch.nn as nn
loss = nn.NLLLoss()  # optional parameters, e.g. weight or reduction
prediction = model(input)  # model outputs log-probabilities
output = loss(prediction, target)

We should match this behavior: once initialized, the loss object can be applied directly to the prediction and target tensors.
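To make the PyTorch pattern concrete, here is a minimal pure-Python sketch of a loss object that is configured once and then called. The class name SimpleNLLLoss is hypothetical and it is not torch.nn.NLLLoss; it assumes predictions are plain lists of log-probabilities (one row per sample), targets are class indices, and only the "mean" and "sum" reductions are supported.

```python
import math
from typing import Sequence


class SimpleNLLLoss:
    """Hypothetical stand-in for torch.nn.NLLLoss: configure once, then call."""

    def __init__(self, reduction: str = "mean") -> None:
        if reduction not in ("mean", "sum"):
            raise ValueError(f"unsupported reduction: {reduction}")
        self.reduction = reduction

    def __call__(self, prediction: Sequence[Sequence[float]], target: Sequence[int]) -> float:
        # prediction[i] holds log-probabilities for sample i; target[i] is its class index
        losses = [-row[t] for row, t in zip(prediction, target)]
        total = sum(losses)
        return total / len(losses) if self.reduction == "mean" else total


loss = SimpleNLLLoss()
prediction = [[math.log(0.7), math.log(0.3)], [math.log(0.2), math.log(0.8)]]
target = [0, 1]
output = loss(prediction, target)  # the object itself is invoked like a function
```

Python supports this through the __call__ method; the question in this issue is how to get the same loss(prediction, target) syntax in C#.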

interesaaat commented 5 years ago

I am thinking about adding a method to the loss class that takes the parameters of the loss function you want to use and returns a lambda function (instead of an object). Do you think this could work?

This is the only way we can actually support the loss(prediction, target) syntax in C#, since C# does not allow overloading the call operator.
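The proposed design, a factory that takes the loss parameters and returns a function rather than an object, can be sketched in Python as follows. This is an illustration of the idea only, not TorchSharp's API: the name nll_loss is hypothetical, and it assumes the same plain-list inputs (log-probabilities per sample, integer class targets) and only the "mean" and "sum" reductions.

```python
import math
from typing import Callable, Sequence

# A loss function maps (prediction, target) to a scalar
LossFn = Callable[[Sequence[Sequence[float]], Sequence[int]], float]


def nll_loss(reduction: str = "mean") -> LossFn:
    """Hypothetical factory: takes the loss parameters and returns a lambda."""
    if reduction not in ("mean", "sum"):
        raise ValueError(f"unsupported reduction: {reduction}")
    # Choose the reduction once, when the loss is configured
    reduce = (lambda xs: sum(xs) / len(xs)) if reduction == "mean" else sum
    # The returned lambda captures `reduce` and computes the negative log-likelihood
    return lambda prediction, target: reduce([-row[t] for row, t in zip(prediction, target)])


loss = nll_loss()
prediction = [[math.log(0.7), math.log(0.3)], [math.log(0.2), math.log(0.8)]]
target = [0, 1]
output = loss(prediction, target)  # same call syntax as the PyTorch example
```

In C# the factory would return a delegate (for example a Func taking the prediction and target and returning the loss), which can then be invoked with the same loss(prediction, target) syntax.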

artidoro commented 5 years ago

Yes, I think that should work!