Sanaxen / cpp_torch

A tiny-dnn-style deep learning framework built on libtorch: header-only, with no dependencies other than libtorch.
MIT License

How do I define the loss function? (BCEWithLogitsLoss) #1

Closed: Sanaxen closed this issue 5 years ago.

Sanaxen commented 5 years ago

How can I define a BCEWithLogitsLoss loss function? Please let me know if you have any idea.

Sanaxen commented 5 years ago

if (!o.requires_grad()) o.set_requires_grad(true);

Is this OK?

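Example:

#include <torch/torch.h>

// Logarithm with a small epsilon so that log(0) never produces -inf
inline torch::Tensor safe_log(torch::Tensor x)
{
    torch::Tensor y = torch::log(torch::abs(x) + 1.0e-12);
    return y;
}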

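// o: raw logits (before sigmoid), t: targets in [0, 1] with the same shape as o
torch::Tensor BCEWithLogitsLoss(torch::Tensor o, torch::Tensor t)
{
    // Mark o as a target of automatic differentiation.
    // A network output usually already requires grad, so this branch
    // mainly matters when o is a plain tensor.
    if (!o.requires_grad()) o.set_requires_grad(true);
    // Note: o.options().requires_grad(true) has no effect on o; it only
    // returns a new TensorOptions object.

    // Binary cross-entropy on sigmoid(o), averaged over all elements
    torch::Tensor p = torch::sigmoid(o);
    torch::Tensor y = -(t * safe_log(p) + (1 - t) * safe_log(1 - p)).mean();

    return y;
}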
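To see what the function above computes per element, and why the epsilon in `safe_log` matters, here is a plain C++ sketch without libtorch. All names (`bce_naive`, `bce_stable`, `bce_mean`, `bce_grad`) are hypothetical and for illustration only. `bce_naive` mirrors the thread's sigmoid + `safe_log` form. `bce_stable` swaps in the standard identity max(x, 0) - x*t + log(1 + exp(-|x|)), which is algebraically the same loss but is not the author's code. `bce_grad` is the analytic derivative sigmoid(x) - t, which autograd should reproduce per element (scaled by 1/N because of `.mean()`).

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Naive per-element loss, mirroring the thread's code:
// -(t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))) with the same 1e-12 epsilon.
double bce_naive(double x, double t) {
    const double s = 1.0 / (1.0 + std::exp(-x));
    return -(t * std::log(std::abs(s) + 1.0e-12) +
             (1.0 - t) * std::log(std::abs(1.0 - s) + 1.0e-12));
}

// Stable rewrite of the same quantity: max(x, 0) - x*t + log(1 + exp(-|x|)).
// exp(-|x|) never overflows, and log1p keeps precision when its argument is tiny.
double bce_stable(double x, double t) {
    return std::max(x, 0.0) - x * t + std::log1p(std::exp(-std::abs(x)));
}

// Mean over a batch, like .mean() in the thread; inputs must have equal length.
double bce_mean(const std::vector<double>& logits, const std::vector<double>& targets) {
    assert(logits.size() == targets.size());
    if (logits.empty()) return 0.0;
    double sum = 0.0;
    for (std::size_t i = 0; i < logits.size(); ++i) {
        sum += bce_stable(logits[i], targets[i]);
    }
    return sum / static_cast<double>(logits.size());
}

// Analytic derivative of the per-element loss with respect to the logit x.
// With .mean(), autograd should return this value divided by the element count.
double bce_grad(double x, double t) {
    return 1.0 / (1.0 + std::exp(-x)) - t;
}
```

For moderate logits the two forms agree: `bce_naive(2.0, 1.0)` and `bce_stable(2.0, 1.0)` differ only by about 1e-12. For a large logit with the opposite target, `1 - sigmoid(40.0)` rounds to 0 in double precision, so `bce_naive(40.0, 0.0)` is capped at -log(1e-12) ≈ 27.63, while `bce_stable(40.0, 0.0)` returns ≈ 40. In float32, which libtorch uses by default, the sigmoid saturates at much smaller logits. Recent libtorch releases also ship a built-in `torch::nn::BCEWithLogitsLoss` module that uses a stable formulation; check that it exists in your libtorch version before relying on it.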