I tried running ESM2 inference (`model(seq_tokens).logits`) in full and half precision (`model.half()`) on an Apple M3 Max chip with torch==2.3.1.
I noticed that with half precision the inference time is ~10x longer, while memory usage drops as expected. Any idea why the runtime increases so drastically?
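In case it helps to reproduce, here is a minimal sketch of how I'm benchmarking. The checkpoint name, input sequence, and the MPS device choice are placeholders/assumptions on my side; the timing synchronizes the MPS queue before and after each run so the numbers are meaningful:

```python
import time
import torch
from transformers import AutoTokenizer, EsmForMaskedLM

device = torch.device("mps")  # assumption: running on the M3 Max GPU via the MPS backend; use "cpu" if applicable
checkpoint = "facebook/esm2_t33_650M_UR50D"  # placeholder; swap in the ESM2 size you actually use

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = EsmForMaskedLM.from_pretrained(checkpoint).to(device).eval()

# Placeholder protein sequence, tokenized to the input expected by ESM2
seq_tokens = tokenizer("MKTAYIAKQRQISFVKSHFSRQ", return_tensors="pt")["input_ids"].to(device)

def bench(m, label, iters=10):
    with torch.no_grad():
        m(seq_tokens)                # warm-up pass (kernel setup, caches)
        torch.mps.synchronize()      # MPS dispatch is async; sync before timing
        t0 = time.perf_counter()
        for _ in range(iters):
            _ = m(seq_tokens).logits
        torch.mps.synchronize()      # sync again so all queued work is counted
    print(f"{label}: {(time.perf_counter() - t0) / iters * 1e3:.1f} ms/iter")

bench(model, "fp32")         # full-precision baseline
bench(model.half(), "fp16")  # .half() converts in place, so the fp32 run must come first
```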