huggingface / nanotron

Minimalistic large language model 3D-parallelism training
Apache License 2.0

Add FLOPs equations for Mamba and fix number of parameters #111

Closed staghado closed 6 months ago

staghado commented 6 months ago

This PR does two things:

  1. Fixes the number-of-parameters calculation in create_config_mamba.py.
  2. Adds model_flops_per_s to the get_flops_per_sec method of the MambaModel class.

Here is a simple comparison between the formula used in get_flops_per_sec (estimated FLOPs) and PyTorch's flop_counter utility (exact FLOPs) for different model sizes. The x-axis is d_model and the y-axis is TFLOPS.
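As a rough illustration of the kind of estimate get_flops_per_sec produces, here is a minimal sketch using the standard 6N-per-token approximation for dense training FLOPs (2N forward, 4N backward). The function names, the example parameter count, and the throughput figure are all illustrative assumptions, not the PR's actual Mamba equations, which additionally account for the selective-scan and convolution terms.

```python
def estimated_model_flops_per_sec(num_params: int, tokens_per_sec: float) -> float:
    """Approximate training FLOPs/s as 6 * params * token throughput.

    This is the standard dense-model approximation; Mamba-specific terms
    (selective scan, conv1d) are deliberately omitted in this sketch.
    """
    return 6 * num_params * tokens_per_sec


def model_flops_utilization(flops_per_sec: float, peak_flops: float) -> float:
    """Fraction of the accelerator's peak FLOPs actually achieved (MFU)."""
    return flops_per_sec / peak_flops


# Illustrative numbers: a 1.4B-parameter model at 20k tokens/s on a GPU
# with a 312 TFLOPS BF16 peak (A100).
flops = estimated_model_flops_per_sec(1_400_000_000, 20_000)
print(f"{flops / 1e12:.1f} TFLOPS, MFU = {model_flops_utilization(flops, 312e12):.1%}")
# → 168.0 TFLOPS, MFU = 53.8%
```

Comparing such an estimate against an exact count (e.g. PyTorch's FlopCounterMode) is a quick sanity check that the analytical formula does not drift as d_model grows.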

[Screenshot (2024-03-17): estimated vs. exact TFLOPS across d_model values]
3outeille commented 6 months ago

LGTM! Thanks for your contribution!