facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins

training ESM-2 using bf16 or fp16 or tf32 or fp32? #259

Closed: zhenyuhe00 closed this issue 2 years ago

zhenyuhe00 commented 2 years ago

Hi, congrats on your great work on ESM-2. I wonder which floating-point format you used to train ESM-2, since I find that running the model in fp16 and in bf16 yields different numerical values, as follows:

```
bf16['representations'][1]
tensor([[ 0.0825, -0.1167, -0.0654,  ..., -0.1011, -0.0023, -0.1523],
        [ 0.0603, -0.0815,  0.0459,  ..., -0.0557,  0.0320, -0.1641],
        [ 0.1045, -0.0854, -0.0698,  ..., -0.1143,  0.1016, -0.1533],
        ...,
        [ 0.0464, -0.0386, -0.0425,  ..., -0.0144,  0.1045, -0.0649],
        [ 0.0273, -0.1060, -0.1289,  ..., -0.0115,  0.1157, -0.0693],
        [ 0.0273, -0.1060, -0.1289,  ..., -0.0115,  0.1157, -0.0693]],
       dtype=torch.bfloat16)

fp16['representations'][1]
tensor([[ 0.0814, -0.1161, -0.0657,  ..., -0.1006, -0.0029, -0.1522],
        [ 0.0607, -0.0811,  0.0446,  ..., -0.0558,  0.0321, -0.1647],
        [ 0.1031, -0.0839, -0.0707,  ..., -0.1141,  0.1027, -0.1550],
        ...,
        [ 0.0468, -0.0378, -0.0432,  ..., -0.0145,  0.1046, -0.0642],
        [ 0.0263, -0.1041, -0.1304,  ..., -0.0126,  0.1151, -0.0701],
        [ 0.0264, -0.1042, -0.1304,  ..., -0.0125,  0.1152, -0.0698]],
       dtype=torch.float16)
```
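For reference, here is a minimal sketch of how such a comparison can be reproduced with the public fair-esm API. The checkpoint and input sequence are illustrative assumptions, not details taken from this issue; the layer index 1 matches the printouts above.

```python
import copy

import torch
import esm  # pip install fair-esm

# Load one fp32 copy of the weights; the exact checkpoint used in this
# issue is not stated, so esm2_t6_8M_UR50D is assumed for illustration.
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()

# Any protein sequence works; this one is a placeholder.
_, _, tokens = batch_converter([("seq1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELS")])
tokens = tokens.cuda()  # fp16 matmuls generally need a GPU

reprs = {}
for name, dtype in [("bf16", torch.bfloat16), ("fp16", torch.float16)]:
    # Cast a fresh copy of the fp32 weights each time, so neither run
    # inherits rounding error from the other cast.
    m = copy.deepcopy(model).to(dtype).cuda()
    with torch.no_grad():
        out = m(tokens, repr_layers=[1])
    reprs[name] = out["representations"][1][0]  # drop batch dim, as in the printout

print(reprs["bf16"])
print(reprs["fp16"])
```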

tomsercu commented 2 years ago

ESM-2 was trained in fp16. The examples you show look like entirely reasonable numerical-precision deviations.
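To put "reasonable" on a quantitative footing, here is a small helper (not part of the esm library) for measuring the gap between two such runs. The tolerances are illustrative, derived from the formats' mantissa widths (bf16 keeps 7 explicit mantissa bits, fp16 keeps 10):

```python
import torch

def precision_gap(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Report how far two reduced-precision runs of the same model diverge."""
    a, b = a.float(), b.float()  # compare in fp32 so the check adds no rounding
    diff = (a - b).abs()
    print(f"max abs diff:  {diff.max().item():.4f}")
    print(f"mean abs diff: {diff.mean().item():.4f}")
    # bf16 eps is 2**-7 ~= 7.8e-3 and fp16 eps is 2**-10 ~= 9.8e-4, so
    # elementwise gaps around 1e-2 at this magnitude are expected noise.
    return torch.allclose(a, b, atol=1e-2, rtol=1e-2)

# e.g., with the tensors from the sketch above:
# precision_gap(reprs["bf16"], reprs["fp16"])  # -> True for healthy runs
```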

zhenyuhe00 commented 2 years ago

Thanks!