facebookresearch / fastText

Library for fast text representation and classification.
https://fasttext.cc/
MIT License
25.83k stars 4.71k forks source link

loss increases with progress, not convergence #1291

Open jylei66 opened 2 years ago

jylei66 commented 2 years ago

I define a focal loss function which has 3 parts: computeOutput, backprop, and loss. When I train with this new focal loss, it happens loss increase with Progress(should decrease and convergence). Would you pls advise how to modify this focal loss function? In loss.cc, I added focal loss:

FocalLoss::FocalLoss(std::shared_ptr& wo) : Loss(wo) {}

void FocalLoss::computeOutput(Model::State& state) const { Vector& output = state.output; output.mul(*wo_, state.hidden); real max = output[0], z = 0; int32_t osz = output.size(); for (int32_t i = 0; i < osz; i++) { max = std::max(output[i], max); } for (int32_t i = 0; i < osz; i++) { output[i] = exp(output[i] - max); z += output[i]; } for (int32_t i = 0; i < osz; i++) { output[i] /= z; } }

real FocalLoss::forward( const std::vector& targets, int32_t targetIndex, Model::State& state, real lr, bool backprop) { computeOutput(state);

assert(targetIndex >= 0);
assert(targetIndex < targets.size());
int32_t target = targets[targetIndex];
real loss = 0.0;

if (backprop) {
    int32_t osz = wo_->size(0);
    real alpha = 0.0;
    for (int32_t i = 0; i < osz; i++) {
        bool labelIsPositive = utils::contains(targets, i);
        real pt = labelIsPositive ? state.output[target] : (1-state.output[target]);
        real pro = state.output[i];
        loss += -FOCAL_ALPHA * pow(1-pt, FOCAL_GAMMA) * std_log(pt);
        if (i != target) {
            alpha = FOCAL_ALPHA * pow(1 - pt, FOCAL_GAMMA - 1) *
                    (FOCAL_GAMMA * (-1 * pt * pro) * std_log(pt) + pro * (1 - pt));
        }
        else {
            alpha = FOCAL_ALPHA * pow(1 - pt, FOCAL_GAMMA) * (FOCAL_GAMMA * pt * std_log(pt) + pt - 1);
        }
        alpha *= lr;
        state.grad.addRow(*wo_, i, alpha);
        wo_->addVectorToRow(state.hidden, i, alpha);
    }
}
return loss;

}