databricks / megablocks

Apache License 2.0

ParallelDroplessMLP initialises self.mlp twice #83

Open 152334H opened 6 months ago

152334H commented 6 months ago

What the title says. In layers/dmoe.py:

class ParallelDroplessMLP(moe.ParallelMLP):

    def __init__(self, args : Arguments):
        super(ParallelDroplessMLP, self).__init__(args) # <-- first init!
        self.hidden_size = args.hidden_size
        self.ffn_hidden_size = mpu.features_per_rank(args)
        self.blocking = 128
        self.mlp = dmlp_registry.get(args) # <-- second init!

As a subclass of moe.ParallelMLP, ParallelDroplessMLP first initialises self.mlp in super().__init__() (in layers/moe.py):

class ParallelMLP(torch.nn.Module):

    def __init__(self, args : Arguments):
        # ... omitted ...

        # Expert MLP.
        self.mlp = mlp.MLP(args)

This costs extra initialisation time and extra memory during init, because the weights created in this first init are immediately overwritten by the new weights created via self.mlp = dmlp_registry.get(args).
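A dependency-free sketch of the pattern described above. DenseExperts, GroupedExperts and the *Sketch classes are hypothetical stand-ins, not the real megablocks modules; they only record how often expert weights get built. The base constructor builds one set, then the subclass builds a second set and throws the first away.

```python
# Hypothetical stand-ins: they only log each construction.
construction_log = []

class DenseExperts:
    """Stand-in for the weights built by mlp.MLP(args)."""
    def __init__(self):
        construction_log.append("DenseExperts")

class GroupedExperts:
    """Stand-in for the weights returned by dmlp_registry.get(args)."""
    def __init__(self):
        construction_log.append("GroupedExperts")

class ParallelMLPSketch:
    def __init__(self):
        self.mlp = DenseExperts()        # first init, in the base class

class ParallelDroplessMLPSketch(ParallelMLPSketch):
    def __init__(self):
        super().__init__()               # builds DenseExperts ...
        self.mlp = GroupedExperts()      # ... then immediately replaces it

layer = ParallelDroplessMLPSketch()
print(construction_log)          # ['DenseExperts', 'GroupedExperts']
print(type(layer.mlp).__name__)  # GroupedExperts
```

Both constructors run, but only the second object survives, which is exactly the wasted work the issue describes.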

Apologies in advance if this double-init process is actually crucially important to the mechanics of the library; I personally did not observe anything breaking after commenting out the first initialisation.

tgale96 commented 6 months ago

Hi! This is a good point and we should probably fix it by refactoring weight initialization into a separate function that can be overloaded by the derived class. We would welcome a PR if you happen to have cycles to make the change! :)
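One way the suggested refactor could look, as a minimal sketch: the base class builds self.mlp through an overridable method, and the dropless subclass overrides that method instead of reassigning self.mlp after super().__init__(). The hook name _build_mlp and all classes here are hypothetical stand-ins, not megablocks code.

```python
construction_log = []

class DenseExperts:
    """Hypothetical stand-in for mlp.MLP(args)."""
    def __init__(self, args):
        construction_log.append("DenseExperts")

class GroupedExperts:
    """Hypothetical stand-in for dmlp_registry.get(args)."""
    def __init__(self, args):
        construction_log.append("GroupedExperts")

class ParallelMLPSketch:
    def __init__(self, args):
        # ... router and other shared setup would go here ...
        self.mlp = self._build_mlp(args)  # single construction site

    def _build_mlp(self, args):
        """Overridable hook: the base class builds the dense expert MLP."""
        return DenseExperts(args)

class ParallelDroplessMLPSketch(ParallelMLPSketch):
    def _build_mlp(self, args):
        # Override the hook rather than reassigning self.mlp afterwards,
        # so only one set of expert weights is ever constructed.
        return GroupedExperts(args)

layer = ParallelDroplessMLPSketch(args=None)
print(construction_log)  # ['GroupedExperts']
```

One design caveat: the hook runs inside the base constructor, so attributes that ParallelDroplessMLP currently sets after super().__init__() (such as self.blocking) would not exist yet. In this sketch the hook derives what it needs from args instead.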

cli99 commented 6 months ago

This double init does not seem to affect memory usage? I printed the allocated memory before and after https://github.com/stanford-futuredata/megablocks/blob/main/megablocks/layers/dmoe_test.py#L41. Although both the MLP init and the mlp_impl init are called, the allocated memory is still only hidden * intermediate * num_experts * bytes per param, plus the router.

152334H commented 6 months ago

Not sure why that happened for you; I get a clear reduction in memory when I initialise dMoE with the first, inner self.mlp initialisation commented out.

tgale96 commented 5 months ago

Do you happen to have a repro for the increased memory usage?

152334H commented 5 months ago

repro:

from megablocks.layers.dmoe import ParallelDroplessMLP
from megablocks.layers.moe import mlp
from megablocks.layers.arguments import Arguments
import sys
import torch

# mixtral-like arguments. some configs disabled for speed.
args = Arguments(
    #hidden_size=4096, #1024
    #ffn_hidden_size=14336, #4096
    #num_layers=num_layers,
    bias=False, # True
    return_bias=False,# True
    #activation_fn=torch.nn.functional.silu, #DEFAULT_ACTIVATION_FN

    # MoE arguments.
    moe_num_experts=8, #1
    moe_top_k=2, #1
    mlp_type='glu', # 'mlp'
    #mlp_impl='grouped', # 'sparse'
    device='cpu', # torch.cuda.current_device()
)

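# Pass 'x' to replace mlp.MLP(args) with a no-op, skipping the first (redundant) init;
# pass anything else (e.g. 'y') to keep the original behaviour.
if sys.argv[1] == 'x':
    mlp.MLP = lambda args: None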
m = ParallelDroplessMLP(args)
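If wrapping the script in /usr/bin/time is inconvenient, a similar peak-RSS figure can be read from inside the process, for example by appending a call after m = ParallelDroplessMLP(args). This is a sketch using the standard-library resource module, which is Unix-only; note that ru_maxrss is reported in kilobytes on Linux but in bytes on macOS.

```python
import resource
import sys

def peak_rss_kib() -> int:
    """Peak resident set size of the current process so far, in KiB.

    On Linux ru_maxrss is already in kilobytes, the same unit as the
    "Maximum resident set size (kbytes)" line printed by /usr/bin/time -v.
    On macOS it is in bytes, so it is converted here.
    """
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if sys.platform == "darwin":
        peak //= 1024
    return peak

print(f"peak RSS: {peak_rss_kib()} KiB")
```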

You can use /usr/bin/time -v to track the peak memory (maximum resident set size) of the Python process, which is higher when the first self.mlp = ... initialisation runs:

$ /usr/bin/time -v python3 lol.py y 2>&1 | grep Maximum
        Maximum resident set size (kbytes): 803400
$ /usr/bin/time -v python3 lol.py x 2>&1 | grep Maximum
        Maximum resident set size (kbytes): 666412
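As a rough sanity check, the expected size of the discarded weights can be compared with the measured gap. This is a back-of-the-envelope sketch under stated assumptions: the hidden and FFN sizes come from the defaults noted in the repro's comments, while the claims that mlp.MLP allocates exactly two weight tensors (w1 and w2) per expert and that parameters are fp16 by default are assumptions, not something stated in this thread.

```python
# Values taken from the repro's comments / arguments.
hidden_size = 1024      # Arguments default, per the "#1024" comment
ffn_hidden_size = 4096  # Arguments default, per the "#4096" comment
num_experts = 8         # moe_num_experts=8 in the repro

# Assumptions (not confirmed in the issue):
weight_tensors = 2      # assumed: mlp.MLP allocates w1 and w2
bytes_per_param = 2     # assumed: fp16 parameters

wasted_bytes = (weight_tensors * num_experts * hidden_size
                * ffn_hidden_size * bytes_per_param)
print(wasted_bytes // 1024, "KiB")  # 131072 KiB (128 MiB)

# Difference between the two /usr/bin/time -v readings above.
observed_kib = 803400 - 666412
print(observed_kib, "KiB")          # 136988 KiB
```

Under these assumptions the discarded first set of expert weights accounts for most of the roughly 134 MB difference in peak RSS; the remainder may come from other transient allocations.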

I assume @cli99 did not see this because they were tracking the final memory usage, presumably after the first self.mlp had already been garbage collected.
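The difference between "final" and "peak" memory can be illustrated with the standard-library tracemalloc module. PyTorch tensors are not allocated through Python's allocator, so this sketch uses a bytearray as a stand-in for one set of expert weights; on GPU, torch.cuda.memory_allocated() and torch.cuda.max_memory_allocated() play the roles of "current" and "peak" respectively. The Layer class and the 50 MiB size are arbitrary choices for illustration.

```python
import tracemalloc

N = 50 * 1024 * 1024  # 50 MiB stand-in for one set of expert weights

class Layer:
    pass

tracemalloc.start()
layer = Layer()
layer.mlp = bytearray(N)  # first "init"
layer.mlp = bytearray(N)  # second "init": allocated before the first is freed
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

# current reflects only the surviving buffer; peak includes the moment
# when both buffers were alive at the same time.
print(f"current: {current / 2**20:.1f} MiB")
print(f"peak:    {peak / 2**20:.1f} MiB")
```

Measuring only the memory that is still allocated after construction (current) hides the transient second buffer, while the peak reading shows it, which matches the two observations in this thread.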

tgale96 commented 5 months ago

Thanks for the repro! This should be relatively easy to fix once we get some free cycles to do the work.