152334H opened 6 months ago
Hi! This is a good point and we should probably fix it by refactoring weight initialization into a separate function that can be overloaded by the derived class. We would welcome a PR if you happen to have cycles to make the change! :)
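A minimal sketch of what that refactor could look like, using a toy stand-in rather than megablocks' actual classes (the `_build_mlp` hook name and everything else here are illustrative assumptions):

```python
class MLP:
    """Stand-in for the expert MLP; counts how many times it is built."""
    instances = 0

    def __init__(self, size):
        MLP.instances += 1
        self.weights = [0.0] * size  # stand-in for real weight tensors


class ParallelMLP:
    """Plays the role of moe.ParallelMLP: weight creation lives in an
    overridable hook instead of being inlined in __init__."""

    def __init__(self, size):
        self.mlp = self._build_mlp(size)

    def _build_mlp(self, size):
        return MLP(size)


class ParallelDroplessMLP(ParallelMLP):
    """Plays the role of dmoe.ParallelDroplessMLP: overrides the hook,
    so the base-class weights are never materialised."""

    def _build_mlp(self, size):
        return MLP(size * 2)  # the subclass's own implementation


m = ParallelDroplessMLP(4)
print(MLP.instances)  # only one MLP is ever constructed
```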
This double init does not seem to affect memory usage? I printed the memory allocation before and after https://github.com/stanford-futuredata/megablocks/blob/main/megablocks/layers/dmoe_test.py#L41; although the MLP init and the mlp_impl init are both called, the allocated memory is still hidden × intermediate × num_experts × bytes-per-param, plus the router.
Not sure why that happened for you; I get a clear and obvious memory reduction when I initialise dMoE with the first inner `self.mlp` initialisation commented out.
Do you happen to have a repro for the increased memory usage?
repro:

```python
from megablocks.layers.dmoe import ParallelDroplessMLP
from megablocks.layers.moe import mlp
from megablocks.layers.arguments import Arguments
import os, psutil, sys
import torch

# Mixtral-like arguments. Some configs disabled for speed.
args = Arguments(
    #hidden_size=4096,      #1024
    #ffn_hidden_size=14336, #4096
    #num_layers=num_layers,
    bias=False,        # True
    return_bias=False, # True
    #activation_fn=torch.nn.functional.silu, #DEFAULT_ACTIVATION_FN
    # MoE arguments.
    moe_num_experts=8, #1
    moe_top_k=2,       #1
    mlp_type='glu',    # 'mlp'
    #mlp_impl='grouped', # 'sparse'
    device='cpu',      # torch.cuda.current_device()
)

if sys.argv[1] == 'x':
    mlp.MLP = lambda a: None  # inject and force-replace mlp.MLP(args) with a no-op
m = ParallelDroplessMLP(args)
```
You can use `/usr/bin/time -v` to track the peak memory of the python process, which is higher with `self.mlp = ...`:

```
$ /usr/bin/time -v python3 lol.py y 2>&1 | grep Maximum
	Maximum resident set size (kbytes): 803400
$ /usr/bin/time -v python3 lol.py x 2>&1 | grep Maximum
	Maximum resident set size (kbytes): 666412
```
I assume @cli99 did not see this because they were tracking the final memory usage, presumably after the first `self.mlp` is garbage collected.
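The gap between final and peak usage is easy to demonstrate with the stdlib `tracemalloc`, with plain lists standing in for the expert weight tensors (sizes here are arbitrary):

```python
import tracemalloc

tracemalloc.start()

# First init (base class): allocates one big buffer...
w = [0.0] * 1_000_000
# ...second init (registry) allocates another while the first is still
# referenced; only after this assignment is the first buffer freed.
w = [0.0] * 1_000_000

current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

# A measurement taken *after* construction (like RSS at the end, or a
# post-init allocation printout) only sees `current`; the transient
# double allocation shows up in `peak`.
print(peak > current)
```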
Thanks for the repro! This should be relatively easy to fix once we get some free cycles to do the work.
What the title says. In `layers/dmoe.py`:

As a subclass of `moe.ParallelMLP`, `ParallelDroplessMLP` first initialises `self.mlp` in `super().__init__()` (at `layers/moe.py`):

This causes extra initialisation time and init memory usage, as the weights created in this init are immediately overwritten by new weights created via `self.mlp = dmlp_registry.get(args)`.

Apologies in advance if this double-init process is actually crucially important to the mechanics of the library; I personally did not observe anything breaking after commenting out the first initialisation.
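For illustration, the pattern being described boils down to the following shape (toy classes, not megablocks' actual code):

```python
class MLP:
    """Stand-in for the expert MLP; counts constructions."""
    constructed = 0

    def __init__(self):
        MLP.constructed += 1
        self.w = [0.0] * 1000  # stand-in for expert weights


class ParallelMLP:  # plays the role of moe.ParallelMLP
    def __init__(self):
        self.mlp = MLP()  # first initialisation


class ParallelDroplessMLP(ParallelMLP):  # plays the role of dmoe's subclass
    def __init__(self):
        super().__init__()  # allocates the base-class MLP...
        self.mlp = MLP()    # ...then immediately overwrites it

m = ParallelDroplessMLP()
print(MLP.constructed)  # two sets of weights built, one thrown away
```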