LayerNorm vs RMSNorm, GptNeoxMLP vs LLaMAMLP
FWIW, the MLP used in stablecode is not benchmarked, as per #742.
# stablecode-completion-alpha-3b
GPT(
  (lm_head): Linear(in_features=2560, out_features=49152, bias=False)
  (transformer): ModuleDict(
    (wte): Embedding(49152, 2560)
    (h): ModuleList(
      (0): Block(
        (norm_1): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (attn): Linear(in_features=2560, out_features=7680, bias=True)
          (proj): Linear(in_features=2560, out_features=2560, bias=True)
        )
        (post_attention_norm): Identity()
        (norm_2): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (mlp): GptNeoxMLP(
          (fc): Linear(in_features=2560, out_features=10240, bias=True)
          (proj): Linear(in_features=10240, out_features=2560, bias=True)
        )
        (post_mlp_norm): Identity()
      )
    )
    (ln_f): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
  )
)
# Llama-3-8B
GPT(
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
  (transformer): ModuleDict(
    (wte): Embedding(128256, 4096)
    (h): ModuleList(
      (0): Block(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (attn): Linear(in_features=4096, out_features=6144, bias=False)
          (proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (post_attention_norm): Identity()
        (norm_2): RMSNorm()
        (mlp): LLaMAMLP(
          (fc_1): Linear(in_features=4096, out_features=14336, bias=False)
          (fc_2): Linear(in_features=4096, out_features=14336, bias=False)
          (proj): Linear(in_features=14336, out_features=4096, bias=False)
        )
        (post_mlp_norm): Identity()
      )
    )
    (ln_f): RMSNorm()
  )
)
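For context, the two architectures differ in both the normalization layer and the MLP structure. Below is a minimal PyTorch sketch of the three building blocks involved; these are illustrative reimplementations with shapes taken from the dumps above, not litgpt's exact code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GptNeoxMLP(nn.Module):
    """Plain two-layer feed-forward with GELU (stablecode-style)."""
    def __init__(self, dim: int = 2560, hidden: int = 10240) -> None:
        super().__init__()
        self.fc = nn.Linear(dim, hidden, bias=True)
        self.proj = nn.Linear(hidden, dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(F.gelu(self.fc(x)))

class LLaMAMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward with no biases (Llama-3-style)."""
    def __init__(self, dim: int = 4096, hidden: int = 14336) -> None:
        super().__init__()
        self.fc_1 = nn.Linear(dim, hidden, bias=False)
        self.fc_2 = nn.Linear(dim, hidden, bias=False)
        self.proj = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # silu(fc_1(x)) gates fc_2(x): three matmuls vs. two above
        return self.proj(F.silu(self.fc_1(x)) * self.fc_2(x))

class RMSNorm(nn.Module):
    """Unlike LayerNorm, RMSNorm skips mean subtraction and has no bias."""
    def __init__(self, dim: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x_normed * self.weight
```

In short, LLaMAMLP does three projections per block where GptNeoxMLP does two, and RMSNorm is a strictly simpler normalization than LayerNorm.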
🐛 Bug
Let's take stablecode-completion-alpha-3b, whose sequence length (Config.block_size) is 16384.
This command goes OOM, while the same config (= model and compile) works on a single H100 with a memory usage of 77.02 GB.
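(The block size can be confirmed programmatically; a small sketch assuming litgpt's Config.from_name API and that this model name is registered there:)

```python
from litgpt import Config

cfg = Config.from_name("stablecode-completion-alpha-3b")
print(cfg.block_size)  # expected: 16384
```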
For the sequence length of 16384:
For the sequence length of 8192:
When --distributed_mode is "fsdp", the benchmark script chooses thunder.distributed.fsdp for the --compile values of thunder without the dynamo keyword, and FSDP1 for the others. Clearly, FSDP2 & Thunder use too much memory even compared to thunder's fsdp, while thunder's fsdp itself seems to use more memory than Eager and Torch Compile. When I was on #940, I didn't see this trend of memory usage. Also, for Llama-3-8B, thunder still uses more memory, but the gap is not that huge.
To Reproduce
Apply a diff like this and run commands like the following:
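(The diff and commands themselves are not included above. A hypothetical invocation for illustration only; the script path and all flag values other than --distributed_mode/--compile are assumptions:)

```sh
# hypothetical; the actual commands were not included in the report
torchrun --nproc-per-node=8 thunder/benchmarks/benchmark_litgpt.py \
  --model_name stablecode-completion-alpha-3b \
  --distributed_mode fsdp \
  --compile thunder
```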
Code sample
Expected behavior
Environment
pjnl-20240919
Additional context
Related to #1175.
cc @carmocca @crcrpar