keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras

Keep rope at float32 precision #1497

Closed grasskin closed 5 months ago

tirthasheshpatel commented 6 months ago

@mattdangerw Gemma still downcasts the tensors to compute_dtype since it uses its own implementation of RoPE. I can submit a follow-up PR to use this layer instead.

mattdangerw commented 6 months ago

I think we also need to cast back to the compute_dtype before returning. Only the computation part needs to happen in float32.
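Something like this, just a rough sketch of the pattern (the `apply_rope` helper and its arguments here are made up for illustration, not the actual Gemma code):

```python
from keras import ops

def apply_rope(x, positions, compute_dtype, max_wavelength=10000):
    # Sketch of a rotary helper: all of the math happens in float32.
    x = ops.cast(x, "float32")
    positions = ops.cast(positions, "float32")
    head_dim = x.shape[-1]
    freq = ops.arange(head_dim // 2, dtype="float32")
    timescale = ops.power(float(max_wavelength), 2.0 * freq / head_dim)
    radians = ops.expand_dims(positions, -1) / timescale
    sin, cos = ops.sin(radians), ops.cos(radians)
    x1, x2 = ops.split(x, 2, axis=-1)
    out = ops.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    # Cast back to the layer's compute dtype only when returning.
    return ops.cast(out, compute_dtype)
```

The rotation math itself will differ from Gemma's; the point is just where the casts sit.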

Yeah, looks like this is causing test failures, probably due to this issue?

danielhanchen commented 6 months ago

Hi :) I'm assuming this came about from my Twitter thread https://twitter.com/danielhanchen/status/1765446273661075609 :)

I added a fix to transformers 4.38.2 here: https://github.com/huggingface/transformers/pull/29285. With mixed_bfloat16, torch.autocast casts all eligible ops to bfloat16, and I know it can override even an explicit float32 cast. I don't normally use Keras, though, so I'm unsure whether its operations are affected the same way.
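For example, a quick standalone check of that behavior in plain PyTorch (nothing Keras-specific):

```python
import torch

a = torch.randn(4, 4, dtype=torch.float32)
b = torch.randn(4, 4, dtype=torch.float32)
# Inside autocast, eligible ops like matmul run in bfloat16 even though the
# inputs were explicitly created as float32.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = a @ b
print(out.dtype)  # torch.bfloat16
```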

Another problematic line is https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/gemma_attention.py#L159

```python
seq_len = ops.shape(x)[1]
start_index = cache_update_index
positions = ops.cast(
    ops.arange(seq_len, dtype="float32"), self.compute_dtype  # <-- the problematic downcast
)
positions = positions + ops.cast(start_index, self.compute_dtype)
```

Which is wrong: if someone did RoPE scaling with float16, the maximum representable value is 65504, so large positions overflow to infinity. bfloat16 loses precision, but it can represent much larger numbers.
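A quick way to see it, backend-agnostic via keras.ops:

```python
from keras import ops

# float16 tops out at 65504, so any position/frequency product beyond that
# overflows to inf; bfloat16 keeps float32's exponent range and stays finite,
# it just rounds more coarsely.
big = ops.convert_to_tensor(70000.0)
print(ops.cast(big, "float16"))   # inf
print(ops.cast(big, "bfloat16"))  # finite (~70000), reduced precision
```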

grasskin commented 5 months ago

Hi @danielhanchen, enjoyed reading the blog post; it was a great in-depth dive!

Switched all of RoPE to happen in "float32" and added downcasting before returning. This should work until we replace the call with the standard Keras RoPE layer?
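In other words, the snippet quoted above ends up looking roughly like this (paraphrasing the idea, not the exact diff):

```python
seq_len = ops.shape(x)[1]
start_index = cache_update_index
# Positions (and the rest of the RoPE math) stay in float32...
positions = ops.arange(seq_len, dtype="float32")
positions = positions + ops.cast(start_index, "float32")
# ...and the result is cast back to self.compute_dtype just before returning.
```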

tirthasheshpatel commented 5 months ago

> Switched all of RoPE to happen in "float32" and added downcasting before returning. This should work until we replace the call with the standard Keras RoPE layer?

Yeah, this should be good enough for now. We can merge this and I can rebase my PR on top of your changes.