mpezeshki / pytorch_forward_forward

Implementation of Hinton's forward-forward (FF) algorithm - an alternative to back-propagation
MIT License
1.44k stars 139 forks

Why use Softplus function in loss? #10

Open FengChendian opened 1 year ago

FengChendian commented 1 year ago

In the train function, your loss is a softplus function: $$loss = \ln(1 + e^x)$$ But in *The Forward-Forward Algorithm: Some Preliminary Investigations*, Hinton uses the logistic function: $$p = \sigma\left(\sum_j y_j^2 - \theta\right)$$ where $$\sigma(x) = \frac{1}{1 + e^{-x}}$$

Is this a mistake or a better choice?

    def train(self, x_pos, x_neg):
        for i in tqdm(range(self.num_epochs)):
            g_pos = self.forward(x_pos).pow(2).mean(1)
            g_neg = self.forward(x_neg).pow(2).mean(1)
            # The following loss pushes pos (neg) samples to
            # values larger (smaller) than the self.threshold.
            loss = torch.log(1 + torch.exp(torch.cat([
                -g_pos + self.threshold,
                g_neg - self.threshold]))).mean()
            self.opt.zero_grad()
            # this backward just computes the derivative and hence
            # is not considered backpropagation.
            loss.backward()
            self.opt.step()
        return self.forward(x_pos).detach(), self.forward(x_neg).detach()
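One way to see the connection (my own check, not stated in the thread): the softplus terms in the code are exactly the negative log-likelihood of Hinton's logistic probability $p = \sigma(g - \theta)$, since $-\log \sigma(g - \theta) = \log(1 + e^{-(g - \theta)})$, which matches the `-g_pos + self.threshold` term above. A small numerical sketch:

```python
import torch
import torch.nn.functional as F

# Sketch: the repo's softplus loss term equals the negative log-likelihood
# of Hinton's logistic probability p = sigmoid(g - theta).
# For a positive sample:
#   -log sigmoid(g - theta) = log(1 + exp(theta - g)) = softplus(theta - g)
g = torch.linspace(-5.0, 5.0, steps=11)  # hypothetical goodness values
theta = 2.0                              # hypothetical threshold

nll_pos = -torch.log(torch.sigmoid(g - theta))  # logistic / BCE view
softplus_pos = F.softplus(-(g - theta))         # repo's loss term

# The two are identical up to floating-point error.
assert torch.allclose(nll_pos, softplus_pos, atol=1e-6)
```

Under this reading, the two formulations are not competing choices: the softplus loss is what you get by training the logistic probability with log-loss.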
taomanwai commented 5 months ago

I observed the following:

  1. Both serve the same purpose: maximize g_pos and minimize g_neg.
  2. The original FF loss (based on the logistic) is bounded in (0, 1), with near-zero gradients when it approaches either 0 or 1. The softplus loss in the current mpezeshki implementation has range (0, +inf), with a near-zero gradient around 0 but a roughly constant gradient across the whole positive range, no matter how large the loss grows.

I believe the roughly constant gradient provides a stable and sufficient gradient signal, facilitating stable and relatively fast learning of the weights, which is why mpezeshki uses SoftPlus instead.
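The gradient claim is easy to check numerically (my own sketch, not from the repo). Treating `x = g - theta` as the loss argument, the derivative of softplus(x) is sigmoid(x), which approaches 1 for badly-misclassified samples, whereas a raw logistic-output loss sigma(x) has derivative sigma(x)·(1 − sigma(x)), which vanishes at both extremes:

```python
import torch

# Compare gradients of the two loss shapes at extreme and moderate inputs.
x = torch.tensor([-10.0, 0.0, 10.0], requires_grad=True)

# Softplus loss (repo's choice): d/dx softplus(x) = sigmoid(x),
# which saturates at 1 (not 0) for large positive x.
torch.nn.functional.softplus(x).sum().backward()
grad_softplus = x.grad.clone()

x.grad = None
# Raw logistic loss sigma(x): d/dx = sigmoid(x) * (1 - sigmoid(x)),
# which vanishes at BOTH extremes.
torch.sigmoid(x).sum().backward()
grad_sigmoid = x.grad.clone()

print(grad_softplus)  # ~[0.0000, 0.5000, 1.0000]
print(grad_sigmoid)   # ~[0.0000, 0.2500, 0.0000]
```

So with softplus, a sample that is far on the wrong side of the threshold still receives a full-strength gradient, while the raw logistic form would give it almost none.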

@mpezeshki, can you verify and confirm this explanation?