Currently, running bf16 llama with long left-padded sequences produces wrong results. This is because we have an integer cache.offset, and casting it to bf16 loses precision for large values (bf16 has only an 8-bit significand, so not every integer above 256 is representable).
Axon used to cast integers according to model policy, but that's no longer the case as of https://github.com/elixir-nx/axon/pull/547, so we can keep offset as an integer and it's all good.
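To illustrate the precision issue (a standalone sketch, not Bumblebee code): bfloat16 is effectively float32 truncated to its top 16 bits, so we can emulate the cast in pure Python and watch a large integer offset get rounded away. The helper name `to_bf16` is made up for this example.

```python
import struct

def to_bf16(x: float) -> float:
    """Round a float to bfloat16 precision (round-to-nearest-even).
    bfloat16 keeps float32's exponent range but only 8 significand bits,
    i.e. the top 16 bits of the float32 representation."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# A small offset survives the cast, but a large one gets rounded,
# which would point the cache at the wrong position:
print(to_bf16(100.0))   # 100.0
print(to_bf16(4097.0))  # 4096.0
```

So once the sequence (and hence the offset) grows past 256, an offset cast to bf16 can silently land on a neighboring value, which matches the wrong results observed with long left-padded sequences.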