datamol-io / graphium

Graphium: Scaling molecular GNNs to infinity.
https://graphium-docs.datamol.io/
Apache License 2.0

MuTransfer implementation of the MuAdam optimizer #128

Closed DomInvivo closed 1 year ago

DomInvivo commented 1 year ago

We want to implement muTransfer in our codebase so that our model's optimal hyperparameters transfer as we scale the model. However, I see that it requires using the optimizer mup.MuAdam instead of torch.optim.Adam.

Do you think this could be supported on IPUs?

hatemhelal commented 1 year ago

This needs testing, but I think it should work by passing poptorch.optim.Adam as the implementation: https://github.com/microsoft/mup/blob/main/mup/optim.py#L38

something like:

MuAdam(model.parameters(), impl=poptorch.optim.Adam)

callumm-graphcore commented 1 year ago

Hi Dominique, are you just interested in using the optimizers from the mup package, or are there other features of the package that you are also interested in using?

DomInvivo commented 1 year ago

Hi @callumm-graphcore , I need to use the mup.init.uniform_ and mup.MuReadout. But I assumed they would be supported natively?

callumm-graphcore commented 1 year ago

Hi Dominique, thanks for letting me know, I just wanted to know what the scope of the task was. I think I've managed to get this working on IPUs, and have attached my test script that demonstrates this below. Sorry for dumping 150 lines into a Github comment, but I can't upload .py files to Github. Please let me know if anything here is unclear.

import argparse

import mup
import poptorch
import torch
from torch import nn

torch.manual_seed(17244)

parser = argparse.ArgumentParser()
parser.add_argument('--ipu', help="Run on IPU", action="store_true")
# The remaining arguments have no sensible default, so mark them as required
# rather than failing later on a comparison against None.
parser.add_argument('--width-lower', help="Smallest width to test", type=int, required=True)
parser.add_argument('--width-upper', help="Largest width to test, inclusive", type=int,
                    required=True)
parser.add_argument('--width-step', help="Step between widths to test", type=int, required=True)
parser.add_argument('--training-steps',
                    help="Number of training steps to run (must be >= 2)", type=int,
                    required=True)
parser.add_argument('--lr', help="Learning rate", type=float, required=True)

args = parser.parse_args()

if args.training_steps < 2:
    raise ValueError("Must run at least 2 training steps")

BASE_WIDTH = 8
DELTA_WIDTH = 16

widths_to_test = list(range(
    args.width_lower,
    args.width_upper + 1,  # +1 makes the upper bound inclusive without overshooting it
    args.width_step
))

class PopReadout(mup.MuReadout):

    """
    NOT a drop-in replacement for mup.MuReadout - you will need to
    pass output_mult = 1.0 / [width of your base model] (or divide
    your value of output_mult by the width of your base model if you
    are passing it already)
    """

    def __init__(self, in_features, *args, base_width=None, **kwargs):

        if base_width is None:
            raise ValueError("base_width must be specified in PopReadout")

        self.base_width = base_width

        self._absolute_width = float(in_features)

        super().__init__(in_features, *args, **kwargs)

    def width_mult(self):
        return self._absolute_width / self.base_width

class CIFAR10MLP(nn.Module):

    def __init__(self, width, ipu):

        super().__init__()

        self.width = width
        self.lin1 = nn.Linear(3*32*32, width)
        self.lin2 = nn.Linear(width, width)
        if ipu:
            self.lin3 = PopReadout(width, 10, base_width=BASE_WIDTH)
        else:
            self.lin3 = mup.MuReadout(width, 10)
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.loss = nn.NLLLoss()

    def forward(self, x, labels=None):

        x = torch.flatten(x, start_dim=1)
        preact1 = self.lin1(x)
        act1 = preact1.relu()
        preact2 = self.lin2(act1)
        act2 = preact2.relu()
        logits = self.lin3(act2)

        out = self.log_softmax(logits)

        if self.training:
            return (act1, act2, out), self.loss(out, labels)
        return out  # in eval mode, return the log-probabilities, not the flattened input

base_model = CIFAR10MLP(width=BASE_WIDTH, ipu=False)
delta_model = CIFAR10MLP(width=DELTA_WIDTH, ipu=False)

examples = [
    (torch.rand(16, 3, 32, 32), torch.randint(high=10, size=(16,)))
    for _ in range(args.training_steps)
]

for width in widths_to_test:

    print(f"Testing width: {width}")

    model = CIFAR10MLP(width=width, ipu=args.ipu)
    mup.set_base_shapes(model, base_model, delta=delta_model)

    for param in model.parameters():
        mup.init.uniform_(param, -0.1, 0.1)

    model.train()

    if args.ipu:

        opts = poptorch.Options()
        optimizer = mup.MuAdam(model.parameters(), lr=args.lr, impl=poptorch.optim.Adam)
        poptorch_model = poptorch.trainingModel(model, options=opts, optimizer=optimizer)

        for index, (inp, label) in enumerate(examples):

            (act1, act2, out), _ = poptorch_model(inp, label)

            if index == 0:
                act1_at_t0 = act1
                act2_at_t0 = act2

            else:
                print(f"stddev(x_{index} - x_0): "
                      f"{(act1 - act1_at_t0).std().item()}, "
                      f"{(act2 - act2_at_t0).std().item()}")

    else:

        optimizer = mup.MuAdam(model.parameters(), lr=args.lr)

        for index, (inp, label) in enumerate(examples):

            optimizer.zero_grad()
            (act1, act2, out), loss = model(inp, label)
            loss.backward()
            optimizer.step()

            if index == 0:
                act1_at_t0 = act1
                act2_at_t0 = act2

            else:
                print(f"stddev(x_{index} - x_0): "
                      f"{(act1 - act1_at_t0).std().item()}, "
                      f"{(act2 - act2_at_t0).std().item()}")

DomInvivo commented 1 year ago

Awesome @callumm-graphcore , I'll check that in more detail and try to implement it next week!

callumm-graphcore commented 1 year ago

Hi Dominique,

I've made some improvements to the script. In particular, I've fixed the docstring for the PopReadout class and made it so you can compare results between IPU and CPU. I have uploaded the script as a Github gist. Please let me know if anything is unclear.

With thanks, Callum Graphcore

DomInvivo commented 1 year ago

Hey @callumm-graphcore , I don't understand the point of PopReadout. Why do you need the additional base_width? Why doesn't MuReadout work out of the box?

callumm-graphcore commented 1 year ago

Hi Dominique, the current implementation of the MuReadout class relies on an infshape attribute being set on the weight tensor so that the width_mult method can return the width multiplier. However, this attribute gets lost somewhere along the way in PopTorch compilation (I don't currently know why this happens or whether it is avoidable). This means that trying to use MuReadout out of the box gives me an AttributeError.

All PopReadout does is override this method to return the correct value in a way that doesn't require the infshape to be known, and add some extra logic to __init__ to make this possible. This works as long as the shape of the layer doesn't change.

It seems reasonable to me to assume that the shape of a layer isn't going to change and so width_mult can return a fixed value. Is this a flawed assumption? Maybe we could try and upstream this change into the mup repo?
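
To make the fixed-multiplier idea concrete, here is a minimal plain-Python sketch (no torch, mup or PopTorch) of the arithmetic that PopReadout relies on. It assumes, as in mup's MuReadout.forward, that the readout input is scaled by output_mult / width_mult before the linear map. The class name FixedWidthReadout and its list-based weights are hypothetical stand-ins for illustration only.

```python
class FixedWidthReadout:
    """Hypothetical toy readout: the width multiplier is a plain float fixed at construction."""

    def __init__(self, weights, base_width, output_mult=1.0):
        # weights: one row per output unit, each row of length in_features
        self.weights = weights
        self.output_mult = output_mult
        in_features = len(weights[0])
        # Computed once on the host, so nothing has to survive on the tensor itself.
        self._width_mult = in_features / base_width

    def width_mult(self):
        return self._width_mult

    def forward(self, x):
        # Scale the input, then apply the (bias-free) linear map.
        scale = self.output_mult / self.width_mult()
        scaled = [scale * v for v in x]
        return [sum(w * v for w, v in zip(row, scaled)) for row in self.weights]


readout = FixedWidthReadout(weights=[[1.0] * 32], base_width=8)
print(readout.width_mult())         # 4.0
print(readout.forward([1.0] * 32))  # [8.0]
```

The multiplier stays valid only while in_features is unchanged after construction, which is the assumption questioned above.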

DomInvivo commented 1 year ago

This workaround is problematic since the infshape is also necessary for every Linear layer, since their initialization must also change.

Could you please check my implementation in the new pull request #144, from the branch utransfer? Basically, I replaced every linear layer by the goli class FCLayer, which uses the method reset_parameters to compute the set_base_shapes. Further, using the parameter is_readout_layer during the initialization of FCLayer switches the Linear to a MuReadout.

This seems to work without errors, but I haven't checked yet if the results on IPU vs CPU match exactly. To check this, you can either use your own code with the FCLayer, or use the file goli/ipu/ipu_simple_lightning.py, which implements a minimally working version of the IPU + lightning.

@callumm-graphcore , can you please check whether the results match using my implementation? And whether the CPU / IPU match? And try to implement the poptorch.optim.Adam in that file?

callumm-graphcore commented 1 year ago

Hi Dominique, I will take a look at the pull request and check that the IPU and CPU results match today.

If I'm not mistaken, the infshape attribute is only necessary for the Linear layers at initialisation, which happens on the CPU and so the attribute doesn't need to be carried through anything that happens on the IPU. Is there a situation where you might need the infshape attribute after initialisation?

DomInvivo commented 1 year ago

No, I don't think I'll ever need the infshape on IPU since, as you said, all layers are initialized on CPU

callumm-graphcore commented 1 year ago

In that case, are you happy with this workaround? If so, I am working on a PR built off of #140 that will add this workaround as an option in FCLayer which should be ready by the end of the day.

DomInvivo commented 1 year ago

@callumm-graphcore , you're talking about adding base_width, _absolute_width and width_mult to FCLayer, in case someone decides to rescale them on IPU? Yeah, that would make sense.

I think you referenced the wrong PR, it should be #144

callumm-graphcore commented 1 year ago

Sorry, you're right, I meant #144.

DomInvivo commented 1 year ago

@callumm-graphcore , I also noticed that any model diverges quite massively since mup was implemented, both on CPU and IPU. The loss starts at ~10^5 and explodes very quickly to 10^11.

I am now testing the file main_run_multitask.py with the config CONFIG_FILE = "expts/configs/config_ipu_allsizes.yaml".

To avoid the explosion, I need to set the gain parameter of every call to xavier_uniform_ to 0.5, but then, after increasing the number of layers, I need to decrease it to 0.2. So the gain depends on the number of layers.

This is a very undesirable property. Since you're more familiar with the mup repo by now, can you investigate why it explodes and try to avoid it, without the need to set the gain parameter?

callumm-graphcore commented 1 year ago

Hi Dominique, I'm happy to take a look at this, but it could very well take some time and will take time away from other things. I've reproduced similar behaviour to what you're describing. Some initial thoughts:

Could you please also explain why changing the gain parameter is so undesirable?

DomInvivo commented 1 year ago

Because the gain will depend on the depth and width of the model. But you're right, I forgot to implement the 1/d-scaled attention.
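
For reference, muP replaces the usual 1/sqrt(d) scaling of attention logits with 1/d, where d is the per-head dimension. A minimal plain-Python sketch of the difference follows. The function name and the list-based vectors are illustrative assumptions, not code from goli or mup.

```python
import math


def attention_weights(query, keys, mup_scaling=True):
    """Softmax attention weights for one query over a list of key vectors."""
    d = len(query)
    # muP-style 1/d scaling versus the standard 1/sqrt(d) scaling.
    scale = 1.0 / d if mup_scaling else 1.0 / math.sqrt(d)
    logits = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    # Subtract the largest logit before exponentiating for numerical stability.
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


d = 64
query = [1.0] * d
keys = [[1.0] * d, [0.0] * d]
weights_standard = attention_weights(query, keys, mup_scaling=False)  # logits 8 and 0
weights_mup = attention_weights(query, keys, mup_scaling=True)        # logits 1 and 0
print(weights_standard)
print(weights_mup)
```

With perfectly aligned query and key vectors, the 1/sqrt(d) logits grow with d and the softmax approaches one-hot, while the 1/d logits stay of order one.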

DomInvivo commented 1 year ago

Regarding regression vs classification, I don't think it matters much, because the initial values are so high that they would also saturate the softmax.

callumm-graphcore commented 1 year ago

I've hacked 1/d-scaled attention into torch.nn.MultiheadAttention in my venv and I'm still seeing the issue. The activations appear to increase exponentially with depth in the first forward pass, so this would seem to be purely an issue with initialisation. I will continue to investigate.
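
A toy model helps show why growth that is exponential in depth makes the required gain depend on the number of layers. The sketch below is plain Python, not the goli model. It assumes a stack of residual linear blocks h <- h + W h with no normalisation and Xavier-uniform weights, which is a simplifying assumption rather than anything confirmed in this thread. Under that assumption the expected squared norm grows by a factor of (1 + gain^2) per block. All function names are hypothetical.

```python
import math
import random


def simulated_growth(width, depth, gain, seed=0):
    """Ratio ||h_depth|| / ||h_0|| for one random residual linear stack."""
    rng = random.Random(seed)
    # Xavier-uniform bound for a square layer: gain * sqrt(6 / (fan_in + fan_out)).
    bound = gain * math.sqrt(6.0 / (2 * width))
    h = [rng.gauss(0.0, 1.0) for _ in range(width)]
    norm0 = math.sqrt(sum(v * v for v in h))
    for _ in range(depth):
        # Fresh random weight row for every output unit of the residual branch.
        branch = [sum(rng.uniform(-bound, bound) * v for v in h) for _ in range(width)]
        h = [a + b for a, b in zip(h, branch)]
    return math.sqrt(sum(v * v for v in h)) / norm0


def rms_growth(depth, gain):
    """Square root of the expected squared-norm growth: (1 + gain^2)^(depth / 2)."""
    return (1.0 + gain ** 2) ** (depth / 2)


def gain_for_target(depth, target):
    """Gain that keeps rms_growth at `target` for a given depth (target > 1)."""
    return math.sqrt(target ** (2.0 / depth) - 1.0)


for gain in (1.0, 0.5, 0.2):
    print(gain, simulated_growth(64, 8, gain), rms_growth(8, gain))
for depth in (8, 16, 32):
    print(depth, gain_for_target(depth, 2.0))
```

In this toy setting, holding the overall growth fixed forces the gain to shrink as depth increases, which mirrors the depth-dependent gain reported earlier in the thread.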

DomInvivo commented 1 year ago

I tried the same thing for the 1/d scaling with the same results.

However, when setting the yaml parameter architecture: gnn: layer_kwargs: attn_type: to "none" instead of "full-attention", it still diverges with large depth, but much less. So there might be an issue with the MultiheadAttentionMup that I implemented.

Do you have an example implementation of the MultiheadAttention layer using mup?

callumm-graphcore commented 1 year ago

> Do you have an example implementation of the MultiheadAttention layer using mup?

No, sadly not

callumm-graphcore commented 1 year ago

In the config you're running, it looks like you're not using normalisation in the GPS layer - is that deliberate?

DomInvivo commented 1 year ago

Ohhh that's a mistake! No more divergence when we add batch_norm. Good catch!

callumm-graphcore commented 1 year ago

I'm glad that worked, but I still have two concerns:

Even if we fix the first of these issues, we'll still have to re-implement reset_parameters with mup support for every new layer we want to implement. I think it'd be best if we took a step back and re-approached this from first principles. Ideally, we'd be able to use set_base_shapes as originally intended (by creating another model with a smaller hidden dim and using this to set the base shapes), which would prevent us from having to do this for each of our layers individually and also eliminate the second problem.
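
As a rough illustration of what the base-model approach provides, the plain-Python sketch below derives per-parameter width multipliers by comparing shapes across a base, a delta and a target model. It is a simplified stand-in for the bookkeeping that mup.set_base_shapes performs, not mup's implementation: the function names, the dict-of-shapes representation, and the choice to report the multiplier of the last scaling dimension are all assumptions made for illustration.

```python
def scaling_dims(base_shape, delta_shape):
    """Indices of dimensions that differ between base and delta, i.e. that scale with width."""
    return [i for i, (b, d) in enumerate(zip(base_shape, delta_shape)) if b != d]


def width_multipliers(base_shapes, delta_shapes, target_shapes):
    """Map each parameter name to target/base size along its last scaling dimension."""
    mults = {}
    for name, base in base_shapes.items():
        dims = scaling_dims(base, delta_shapes[name])
        if not dims:
            mults[name] = 1.0  # no dimension of this parameter scales with width
            continue
        dim = dims[-1]  # for an (out, in) weight this is the fan-in when both dims scale
        mults[name] = target_shapes[name][dim] / base[dim]
    return mults


def mlp_shapes(width):
    # Same layer sizes as the CIFAR10MLP test script: 3*32*32 -> width -> width -> 10
    return {
        "lin1.weight": (width, 3 * 32 * 32),
        "lin2.weight": (width, width),
        "lin3.weight": (10, width),
        "lin3.bias": (10,),
    }


multipliers = width_multipliers(mlp_shapes(8), mlp_shapes(16), mlp_shapes(64))
print(multipliers)
```

Because the multipliers are derived from whole-model shape comparisons, no individual layer needs its own base-shape logic in this scheme.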

There are at least two ways to implement this:

The first of these would be more straightforward to use but would be less flexible than the second option.

I can submit a PR implementing muTransfer in this manner if you would like. We should definitely still keep your MultiheadAttentionMup implementation, though. What do you think? I am very happy to have a more involved discussion about any of this if it is unclear.