pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.
BSD 3-Clause "New" or "Revised" License

mmap issue in bf16 of gpt-fast #165

Open yanbing-j opened 2 months ago


gpt-fast uses torch.load with mmap=True to load model checkpoints, which can speed up model loading. For bf16 models, however, the memory map ends up unused. At https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py#L247, the model is converted from float16 to bfloat16 with `to`. That call allocates new memory for the converted weights, so the mapped file is no longer used.

For int8/int4, the logic at https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py#L247 does not make sense either: an int8 model should not be converted to the bfloat16 data type. int8/int4 currently work only by chance, because the quantized weight is not a parameter of the int8/int4 modules.
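`nn.Module.to(dtype=...)` casts only floating-point and complex parameters and buffers. Integer tensors keep their dtype. The toy module below is a hypothetical stand-in, not gpt-fast's actual int8 layer. It stores an int8 weight and float32 scales as buffers. After the blanket bfloat16 cast, the int8 weight is untouched, but the scales are silently converted to bfloat16.

```python
import torch
import torch.nn as nn


class ToyInt8Linear(nn.Module):
    """Hypothetical weight-only int8 linear layer (illustration only)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Quantized weight kept as an int8 buffer, not an nn.Parameter.
        self.register_buffer(
            "weight", torch.zeros(out_features, in_features, dtype=torch.int8)
        )
        # Per-output-channel scales in floating point.
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(x, self.weight.to(x.dtype)) * self.scales.to(x.dtype)


m = ToyInt8Linear(4, 2).to(dtype=torch.bfloat16)
print(m.weight.dtype, m.scales.dtype)  # torch.int8 torch.bfloat16
```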