evolutionaryscale / esm


global flag for bfloat16? #70

Open mlsbio opened 3 months ago

mlsbio commented 3 months ago

Is there some global flag to make everything bfloat16? I recently had to go through the code and hardcode bfloat16, which seems kind of silly:

[screenshot of the hardcoded bfloat16 changes]

Otherwise, I was getting bfloat16 vs. float dtype-mismatch errors.
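For context, the kind of error described above can be reproduced in plain PyTorch whenever a float32 tensor meets bfloat16 weights outside an autocast region. The sketch below is a minimal illustration under that assumption; the `nn.Linear` layer is a hypothetical stand-in, not actual ESM3 code.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for one layer of a model whose weights are stored in
# bfloat16 (illustrative only; not ESM3 code).
layer = nn.Linear(8, 4).to(torch.bfloat16)

# A float input created with PyTorch's default dtype (float32).
x = torch.randn(2, 8)

# Outside autocast, mixing a float32 input with bfloat16 weights raises.
try:
    layer(x)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = str(err)
print("mismatch:", mismatch_error)

# One manual fix: cast the input to the parameter dtype before the call.
y_fixed = layer(x.to(layer.weight.dtype))
print(y_fixed.dtype)  # torch.bfloat16
```

Casting each input by hand works, but it is essentially the "hardcode bfloat16 everywhere" approach from the question.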

mlsbio commented 3 months ago

(Oops, this is Sergey; I accidentally posted from a different account. Haha.)

sokrypton commented 3 months ago

Edits I had to make: https://github.com/sokrypton/esm3/commit/f8a9f0dca35cd731f488c7830aa5c155fa91d14b

ebetica commented 2 months ago

How are you running into these errors? Do you have a repro script? It should be running under an autocast context: https://github.com/evolutionaryscale/esm/blob/10077d8a8e120f632dee0ea25e68008c4993b535/esm/models/esm3.py#L529

There's also an example of invoking the raw forward function: https://github.com/evolutionaryscale/esm/blob/10077d8a8e120f632dee0ea25e68008c4993b535/examples/raw_forwards.py
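As a rough illustration of why running under an autocast context avoids the mismatch, here is a minimal sketch of the general `torch.autocast` mechanism. It uses CPU autocast and a hypothetical bfloat16 `nn.Linear` layer, not the actual code at the linked line: inside the context, eligible ops such as `linear` cast their floating-point inputs to the autocast dtype, so a float32 input no longer clashes with bfloat16 weights.

```python
import torch
import torch.nn as nn

# Hypothetical bfloat16 layer (illustrative stand-in, not ESM3 code).
layer = nn.Linear(8, 4).to(torch.bfloat16)
x = torch.randn(2, 8)  # float32 input

# Under autocast, linear runs in bfloat16 and the float32 input is cast for us.
# On a GPU the device_type would be "cuda" instead of "cpu".
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = layer(x)

print(y.dtype)  # torch.bfloat16
```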

lhallee commented 1 month ago

> How are you running into these errors? Do you have a repro script? It should be running under an autocast context:
>
> https://github.com/evolutionaryscale/esm/blob/10077d8a8e120f632dee0ea25e68008c4993b535/esm/models/esm3.py#L529
>
> There's also an example of invoking the raw forward function: https://github.com/evolutionaryscale/esm/blob/10077d8a8e120f632dee0ea25e68008c4993b535/examples/raw_forwards.py

As far as I can tell, autocast isn't applied in `encode` or `decode`, so calling either of them raises a dtype error in the `EncodeInputs` class. I'm not sure what changed to break this, because both functions worked fine in July (I think the original weights were all float32?).
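If `encode` and `decode` indeed run outside autocast, one possible workaround is to wrap those methods so they execute inside an autocast context. The sketch below assumes that; `run_under_autocast` and `ToyModel` are hypothetical names used only for illustration and are not part of the esm package.

```python
import functools

import torch
import torch.nn as nn


def run_under_autocast(fn, device_type="cpu", dtype=torch.bfloat16):
    """Hypothetical helper: return a version of fn that runs inside torch.autocast."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with torch.autocast(device_type=device_type, dtype=dtype):
            return fn(*args, **kwargs)
    return wrapper


class ToyModel(nn.Module):
    """Hypothetical stand-in whose encode method is not wrapped in autocast."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 4).to(torch.bfloat16)

    def encode(self, x):
        return self.proj(x)


model = ToyModel()
# Patch the instance method so every call runs under autocast.
model.encode = run_under_autocast(model.encode)
out = model.encode(torch.randn(2, 8))  # float32 input, bfloat16 weights
print(out.dtype)  # torch.bfloat16
```

Wrapping the call site in `with torch.autocast(...)` directly would work just as well; the wrapper only avoids repeating it at every call.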