Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Training a simple XOR network yields incorrect, non-deterministic behaviour #18975

Closed Fohlen closed 10 months ago

Fohlen commented 10 months ago

Bug description

Hi, I am trying to train a simple DNN to solve the XOR problem. This is trivial with a pure PyTorch implementation, but I cannot replicate the same simple model in Lightning. Instead, the trained model oscillates between different states and never correctly produces XOR.

What version are you seeing the problem on?

v2.1

How to reproduce the bug

# import libraries
import torch
import torch.nn as nn

class XOR(nn.Module):
    def __init__(self):
        super(XOR, self).__init__()
        self.linear_sigmoid_stack = nn.Sequential(
            nn.Linear(2, 2),
            nn.Sigmoid(),
            nn.Linear(2, 1)
        )

    def forward(self, x):
        return self.linear_sigmoid_stack(x)

if __name__ == "__main__":
    # create data
    Xs = torch.Tensor([[0., 0.],
                       [0., 1.],
                       [1., 0.],
                       [1., 1.]])

    y = torch.Tensor([0., 1., 1., 0.]).reshape(Xs.shape[0], 1)

    xor_network = XOR()

    epochs = 1000
    mseloss = nn.MSELoss()
    optimizer = torch.optim.Adam(xor_network.parameters(), lr=0.03)
    all_losses = []
    current_loss = 0
    plot_every = 50

    for epoch in range(epochs):

        # run the forward pass on all training examples
        yhat = xor_network(Xs)

        # calculate MSE loss
        loss = mseloss(yhat, y)

        # backpropagate the loss to compute the gradients
        loss.backward()

        # update model weights
        optimizer.step()

        # remove current gradients for next iteration
        optimizer.zero_grad()

        # accumulate the loss as a plain float and record a running average
        current_loss += loss.item()
        if (epoch + 1) % plot_every == 0:
            all_losses.append(current_loss / plot_every)
            current_loss = 0

        # print progress
        if epoch % 500 == 0:
            print(f'Epoch: {epoch} completed')

I tried to use Lightning to simplify away the boilerplate code like so:

import torch
from torch import nn
import torch.nn.functional as F
import lightning as L
from torch.utils.data import TensorDataset, DataLoader

class XORNetwork(L.LightningModule):
    def __init__(self):
        super(XORNetwork, self).__init__()
        self.linear_sigmoid_stack = nn.Sequential(
            nn.Linear(2, 2),
            nn.Sigmoid(),
            nn.Linear(2, 1)
        )

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, y = batch
        yhat = self(x)
        loss = F.mse_loss(yhat, y)
        return loss

    def forward(self, x):
        return self.linear_sigmoid_stack(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

if __name__ == "__main__":
    X = torch.Tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
    # reshape the labels to (4, 1) so they match the network's output shape
    labels = torch.Tensor([0., 1., 1., 0.]).reshape(-1, 1)
    dataset = TensorDataset(X, labels)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    xor_network = XORNetwork()

    # train model
    trainer = L.Trainer(max_epochs=500, accelerator="cpu")
    trainer.fit(model=xor_network, train_dataloaders=dataloader)

    xor_network.eval()
    with torch.no_grad():
        test_output = xor_network(X)
        print(test_output.round())


Error messages and logs

_No response_

Environment

<details>
  <summary>Current environment</summary>

* CUDA:
    - GPU:               None
    - available:         False
    - version:           None
* Lightning:
    - lightning:         2.1.1
    - lightning-utilities: 0.9.0
    - pytorch-lightning: 2.1.1
    - torch:             2.1.0
    - torchmetrics:      1.2.0
* Packages:
    - aiohttp:           3.8.6
    - aiosignal:         1.3.1
    - async-timeout:     4.0.3
    - attrs:             23.1.0
    - certifi:           2023.7.22
    - charset-normalizer: 3.3.2
    - filelock:          3.13.1
    - frozenlist:        1.4.0
    - fsspec:            2023.10.0
    - idna:              3.4
    - jinja2:            3.1.2
    - lightning:         2.1.1
    - lightning-utilities: 0.9.0
    - markupsafe:        2.1.3
    - mpmath:            1.3.0
    - multidict:         6.0.4
    - networkx:          3.2.1
    - numpy:             1.26.1
    - packaging:         23.2
    - pip:               22.3.1
    - pytorch-lightning: 2.1.1
    - pyyaml:            6.0.1
    - requests:          2.31.0
    - setuptools:        65.5.1
    - sympy:             1.12
    - torch:             2.1.0
    - torchmetrics:      1.2.0
    - tqdm:              4.66.1
    - typing-extensions: 4.8.0
    - urllib3:           2.0.7
    - wheel:             0.38.4
    - yarl:              1.9.2
* System:
    - OS:                Darwin
    - architecture:
        - 64bit
        - 
    - processor:         arm
    - python:            3.10.13
    - release:           23.0.0
    - version:           Darwin Kernel Version 23.0.0: Fri Sep 15 14:41:34 PDT 2023; root:xnu-10002.1.13~1/RELEASE_ARM64_T8103

</details>

More info

_No response_

awaelchli commented 10 months ago

@Fohlen In your Lightning code,

  1. You didn't choose the same learning rate. Make it 0.03 in both cases.
  2. You didn't run for the same number of epochs. Make it 1000 in both cases.

And in the raw PyTorch code you are missing the test code:

    xor_network.eval()
    with torch.no_grad():
        test_output = xor_network(Xs)
        print(test_output.round())

For the two scripts to behave the same, the hyperparameters need to match, of course. Can you try again? I get the correct predictions (i.e. 0 1 1 0) after these fixes.
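Besides the learning rate and the epoch count, the two posted scripts also differ in how often the optimizer steps: the pure PyTorch loop feeds all four samples at once, while the Lightning script builds its `DataLoader` with `batch_size=1, shuffle=True`. The sketch below only counts batches per epoch under each setup. `steps_per_epoch` is a hypothetical helper introduced for illustration, and the labels are assumed to be shaped `(4, 1)`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

X = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = torch.tensor([[0.], [1.], [1.], [0.]])
dataset = TensorDataset(X, y)


def steps_per_epoch(batch_size: int) -> int:
    """Number of batches (and hence optimizer steps) in one pass over the data."""
    return len(DataLoader(dataset, batch_size=batch_size, shuffle=True))


# batch_size=1 mirrors the Lightning script; batch_size=4 mirrors the full-batch loop.
per_sample_steps = steps_per_epoch(1)
full_batch_steps = steps_per_epoch(4)
print(per_sample_steps, full_batch_steps)
```

With a batch size of 1, each epoch performs four optimizer steps on single, shuffled samples instead of one step on the full batch, so matching the epoch count alone does not make the two training runs step-for-step identical.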

In addition, to make it fully deterministic you can set the seed:

    L.seed_everything(0)
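As a rough illustration of what seeding buys you, the sketch below uses plain `torch.manual_seed` (as far as I know, `L.seed_everything` calls it too, along with seeding Python's `random` and NumPy) and checks that two models built after the same seed start from identical weights. `make_model` and `same_weights` are hypothetical helpers, not part of the code posted in this issue.

```python
import torch
from torch import nn


def make_model(seed: int) -> nn.Sequential:
    """Build the 2-2-1 XOR network after fixing the global torch seed."""
    torch.manual_seed(seed)
    return nn.Sequential(nn.Linear(2, 2), nn.Sigmoid(), nn.Linear(2, 1))


def same_weights(a: nn.Module, b: nn.Module) -> bool:
    """True if every parameter tensor of the two models is identical."""
    return all(torch.equal(p, q) for p, q in zip(a.parameters(), b.parameters()))


# Same seed: identical initial weights. Different seeds: different weights.
same_seed = same_weights(make_model(0), make_model(0))
different_seed = same_weights(make_model(0), make_model(1))
print(same_seed, different_seed)
```

The seed has to be set before the model is constructed, since the initial weights are drawn at construction time. With `shuffle=True` and no explicit generator, the order in which the `DataLoader` draws samples should also follow from the global torch seed.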

Fohlen commented 10 months ago

Hi @awaelchli, this indeed produces the correct result. I can get the code to converge correctly within 100 epochs or less with pure Torch, any idea why that wouldn't be the case with lightning?

awaelchli commented 10 months ago

> I can get the code to converge correctly within 100 epochs or less with pure Torch

The code that you posted can't actually converge in 100 epochs. Please share what you changed to make that possible.

Fohlen commented 10 months ago

Sorry for the imprecise wording. After some experimentation with the number of epochs, I could produce the correct result after 250 epochs (though not convergence). However, this appears to be extremely sensitive to the seed used for training, which I find interesting. According to the Deep Learning book, the correct weights should be learned with a single pass of this network. In any case, this behaviour is not Lightning-specific. Thanks for your help, I will keep digging in torch to find out the reason for this behaviour 👍
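The seed sensitivity described above can be measured with a small sweep. The following is a sketch in plain PyTorch, assuming the full-batch setup and `lr=0.03` from this thread. `solves_xor` and the choice of ten seeds are illustrative only, and no particular success rate is claimed.

```python
import torch
from torch import nn

X = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = torch.tensor([[0.], [1.], [1.], [0.]])


def solves_xor(seed: int, epochs: int = 250, lr: float = 0.03) -> bool:
    """Train the 2-2-1 network from a given seed and check the rounded outputs."""
    torch.manual_seed(seed)  # fixes the weight initialisation
    model = nn.Sequential(nn.Linear(2, 2), nn.Sigmoid(), nn.Linear(2, 1))
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(X), y)
        loss.backward()
        optimizer.step()
    model.eval()
    with torch.no_grad():
        # True only if all four rounded predictions match the XOR targets
        return torch.equal(model(X).round(), y)


results = {seed: solves_xor(seed) for seed in range(10)}
print(f"{sum(results.values())}/{len(results)} seeds produced XOR after 250 epochs")
```

Re-running the sweep with `epochs=1000`, the value suggested earlier in the thread, is one way to check whether the failing seeds are merely slow or are stuck in a poor solution.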