bclarkson-code / Tricycle

Autograd to GPT-2 completely from scratch

Mixed precision support #84

Closed · bclarkson-code closed this issue 1 month ago

bclarkson-code commented 1 month ago

Currently, mixed precision is implemented but is disabled because training is unstable (exploding gradients).

The source of this instability should be investigated and fixed.

kddubey commented 1 month ago

Hello. Awesome package!

Are loss scaling and de-scaling both performed? Including loss scaling while missing the de-scaling step (after upcasting gradients from fp16 to fp32) could cause exploding gradients.

So this line for SGD should look something like—

        if TRICYCLE_CONTEXT.use_mixed_precision:
            tensor.array.grad = tensor.array.grad.astype(xp.float32) / scaling_factor

—where scaling_factor is the same number the loss was multiplied by before running the backward pass.
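
The effect is easy to see with plain numpy (a standalone illustration; the 2**13 factor is just an arbitrary large power of two, not a recommendation):

import numpy as np

scaling_factor = np.float32(2**13)

# stand-in for a tiny gradient deep in the network, as fp32 would see it
true_grad = np.float32(2e-8)

# without scaling, the fp16 backward pass flushes it to zero
print(np.float16(true_grad))                     # 0.0

# scaling the loss keeps the (scaled) gradient representable in fp16,
# and de-scaling after the fp32 upcast recovers the true value
scaled_fp16 = np.float16(true_grad * scaling_factor)
print(np.float32(scaled_fp16) / scaling_factor)  # ~2e-08

# if the de-scale is skipped, the optimiser sees a gradient 8192x too large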

bclarkson-code commented 1 month ago

Hi @kddubey, thanks! I really appreciate both your compliment and your contribution.

I think you are absolutely right, loss scaling is something I missed. I’m away from my computer right now but I’ll implement it when I get back, unless you would like to make a pull request?

I suspect that loss scaling isn't the only issue with the first draft of mixed precision training, so the plan is to first write some tests to figure out where the issues are, then fix them until mixed precision is stable enough to be enabled.

kddubey commented 1 month ago

unless you would like to make a pull request?

tbh I'm kind of a PyTorch monkey, but I'm happy to at least take a look

Maybe you wanna consider adding a GradScaler object in the same way PyTorch does? scaler.update() dynamically adjusts the scaling factor across training iterations to automatically avoid underflow, which seems like a useful (and maybe necessary) abstraction. I double-checked that the precision-sensitive computations are already handled by Tricycle, so hopefully the missing de-scaling is the cause of the instability.
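
Concretely, the PyTorch pattern I'm thinking of looks roughly like this (standard torch.cuda.amp usage on a throwaway model, needs a CUDA GPU; nothing Tricycle-specific here):

import torch

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()

inputs = torch.randn(32, 16, device="cuda")
targets = torch.randn(32, 16, device="cuda")

for step in range(100):
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()  # scale the loss so the fp16 grads don't underflow
    scaler.step(optimizer)         # unscales the grads; skips the step if any are inf/nan
    scaler.update()                # shrinks the scale after an overflow, grows it after a clean streak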

bclarkson-code commented 1 month ago

That link from NVIDIA is exactly what we need, thanks!

As per the link you shared, we'll definitely need to detect whether we are underflowing (the percentage of zeros in the gradient?) or overflowing (any NaNs or infs) and adjust the scaling factor accordingly.

I think a stand-alone grad scaler object sounds like a nice idea. It might end up being neater to keep the logic in the generic Optimiser object, but we’ll see once it is implemented.
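
Roughly what I have in mind, as a standalone numpy sketch (the function name, the 50% zero threshold and the doubling/halving factors are placeholders to be tuned, not the final design):

import numpy as np

def adjust_scale(grads: list[np.ndarray], scale: float) -> tuple[float, bool]:
    """Pick the next loss scale from the current (de-scaled) gradients.

    Returns (new_scale, skip_this_step).
    """
    # overflow: any nan/inf means the scale is too large; halve it and skip the update
    if any(not np.isfinite(grad).all() for grad in grads):
        return scale / 2, True

    # underflow: a large fraction of exact zeros suggests the scale is too small; grow it
    n_zeros = sum(int((grad == 0).sum()) for grad in grads)
    n_total = sum(grad.size for grad in grads)
    if n_zeros / n_total > 0.5:
        return scale * 2, False

    return scale, False

# e.g. an overflowing gradient halves the scale and skips the update
print(adjust_scale([np.array([1.0, np.inf])], scale=2**13))  # (4096.0, True)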

bclarkson-code commented 1 month ago

I added this test, which is unstable with the current mixed precision implementation:

import logging

import numpy as np

from tricycle import TRICYCLE_CONTEXT
from tricycle.layers import Dense, Layer
from tricycle.loss import MeanSquaredError
from tricycle.optimisers import StochasticGradientDescent
from tricycle.tensor import Tensor
from tricycle.utils import UseMixedPrecision

logger = logging.getLogger(__name__)

class LongBoi(Layer):
    """
    A very deep MLP with no nonlinearities, designed to underflow in mixed
    precision training
    """

    def __init__(self, n_layers: int = 16):
        self.layers = [
            Dense(to_size=16, from_size=16, name=f"layer_{i}")
            for i in range(n_layers)
        ]

    def forward(self, tensor: Tensor) -> Tensor:
        for layer in self.layers:
            tensor = layer(tensor)
        return tensor

    def zero_grad(self):
        for layer in self.layers:
            layer.zero_grad()

    def update(self, optimiser):
        for layer in self.layers:
            layer.update(optimiser)

def test_can_train_in_mixed_precision():
    """
    Check that a model can be trained in mixed precision without overflowing

    We're using a very deep model with no nonlinearities that should cause
    gradient issues if mixed precision is broken
    """
    np.random.seed(0)
    learning_rate = 1e-3
    weight_decay = 1e-1
    model = LongBoi(64)

    loss_fn = MeanSquaredError()
    optimiser = StochasticGradientDescent(
        learning_rate=learning_rate, weight_decay=weight_decay, logger=logger
    )

    inputs = Tensor(
        np.random.random(
            (32, 16),
        ),
        is_batched=True,
        requires_grad=False,
    )
    outputs = Tensor(
        np.random.random(
            (32, 16),
        ),
        is_batched=True,
        requires_grad=False,
    )

    with UseMixedPrecision():
        first_loop = True
        for step in range(100):
            logits = model(inputs)
            loss = loss_fn(outputs, logits)
            loss.backward()
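            # the loss is scaled by loss_scale_factor internally, so de-scale it before logging/asserting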
            loss = loss.numpy().item() / TRICYCLE_CONTEXT.loss_scale_factor
            if first_loop:
                # make sure we start with a big loss
                assert loss > 50
                first_loop = False
            logger.info(f"{loss=}, {TRICYCLE_CONTEXT.loss_scale_factor=}")
            model.update(optimiser)

        # make sure the loss has decreased as expected
        assert 7.5 < loss < 8

If TRICYCLE_CONTEXT.loss_scale_factor is set to 1, we overflow immediately.
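
For reference, fp16 runs out of headroom quickly, which is easy to check with numpy:

import numpy as np

print(np.finfo(np.float16).max)  # 65504.0 -- anything bigger becomes inf
print(np.float16(7e4))           # inf     -- e.g. a large scaled loss overflows straight away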

bclarkson-code commented 1 month ago

But now that I've added dynamic loss scaling, it no longer happens, and GPT-2 seems to train stably in mixed precision too. Looks like you were right, @kddubey. Thanks for all the help!

kddubey commented 1 month ago

Wow, that was fast. Nice test, and it's excellent to hear that GPT-2 training in mixed precision is working!

I (and I'm sure others) also appreciate this project's adherence to the issue-PR workflow instead of a bunch of straight-to-main commits (like I do lol). It makes learning from your work a neat process!