Open mlsbio opened 3 months ago
(oops, this is Sergey, I accidentally posted from a different account). haha
Edits I had to make: https://github.com/sokrypton/esm3/commit/f8a9f0dca35cd731f488c7830aa5c155fa91d14b
How are you running into these errors? Do you have a repro script? It should be running under an autocast context: https://github.com/evolutionaryscale/esm/blob/10077d8a8e120f632dee0ea25e68008c4993b535/esm/models/esm3.py#L529
There's also an example of invoking the raw forward function: https://github.com/evolutionaryscale/esm/blob/10077d8a8e120f632dee0ea25e68008c4993b535/examples/raw_forwards.py
Autocast is not called in encode or decode as far as I can tell, so when calling these you'll get a dtype error in the EncodeInputs class. Not sure what they changed to break it because both functions were working fine in July without this issue (I think the original weights were all float32?).
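For anyone trying to reproduce the mismatch outside of ESM3, here is a minimal sketch in plain PyTorch. It assumes a toy `nn.Linear` as a hypothetical stand-in for a model with bfloat16 weights, and it runs on CPU; it is not the ESM3 code path. It only illustrates why a float32 input fails against bfloat16 weights and why wrapping the call in `torch.autocast` avoids that.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a model whose weights are stored in bfloat16.
layer = nn.Linear(4, 2).to(torch.bfloat16)

# float32 input, which is what a caller gets by default from torch.randn.
x = torch.randn(3, 4)

# Without autocast, the float32 input meets bfloat16 weights and the
# matmul raises a dtype-mismatch RuntimeError.
try:
    layer(x)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# Under autocast, eligible ops such as linear cast their inputs to
# bfloat16, so the same call goes through.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = layer(x)

print(mismatch_raised, out.dtype)
```

On a GPU the same pattern should apply with `device_type="cuda"`. Whether `encode` and `decode` ought to open this context themselves is a question for the maintainers.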
Is there some global flag to make everything bfloat16? I recently had to go through the code and hardcode bfloat16, which seems kind of silly.
Otherwise, I was getting bfloat16 vs float errors.