aws-neuron / aws-neuron-sdk

Powering AWS purpose-built machine learning chips. Blazing fast and cost effective, natively integrated into PyTorch and TensorFlow and integrated with your favorite AWS services
https://aws.amazon.com/machine-learning/neuron/

tf.aliasing support #1026

Open steeve opened 3 weeks ago

steeve commented 3 weeks ago

Hi,

We (@zml) found that tf.aliasing support (the tf.aliasing_output attribute) does not seem to work as expected: the model, Llama 3.1 8B in our case, produces garbage output when it is used. This is problematic for transformer models because we rely on buffer donation for the KvCache.

For now we're not emitting those attributes when running on Neuron, but we're not sure what to do: we feel that if the SDK doesn't support them, it should just ignore them, right?

The llama implementation is attached.

Thank you!

llama.aliasing.mlir.txt

nalwayaakshay commented 2 weeks ago

Can you try using jax.buffer_donor rather than tf.aliasing_output to annotate donated buffers?

For example: %arg2: tensor<...> {jax.buffer_donor = true, mhlo.layout_mode = "default", mhlo.sharding = "{devices=[1,1,32,1]<=[32]}"} loc("state.kv_cache[0]['cached_key']")

From your .txt file: %arg291: tensor<256xi32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}", tf.aliasing_output = 0 : i32}
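To make the suggested change concrete, here is a minimal sketch that rewrites the argument attributes in an exported MLIR text dump. It is a plain text substitution, not part of the Neuron SDK or of zml; the helper name to_buffer_donor is hypothetical, and it assumes the attribute always appears in the form shown in the attached file (tf.aliasing_output = N : i32). Note that the rewrite drops the explicit output index, so the choice of which output reuses the donated buffer is left to the compiler.

```python
import re

# Assumption: the attribute is always printed as `tf.aliasing_output = N : i32`,
# as in the line quoted from the attached .txt file.
_ALIASING_ATTR = re.compile(r"tf\.aliasing_output\s*=\s*\d+\s*:\s*i32")


def to_buffer_donor(mlir_text: str) -> str:
    """Replace every tf.aliasing_output attribute with jax.buffer_donor = true.

    Hypothetical helper for experimentation; it only edits text and does not
    validate the resulting MLIR.
    """
    return _ALIASING_ATTR.sub("jax.buffer_donor = true", mlir_text)


# Argument declaration quoted from the attached file.
arg = (
    '%arg291: tensor<256xi32> {mhlo.layout_mode = "default", '
    'mhlo.sharding = "{replicated}", tf.aliasing_output = 0 : i32}'
)
converted = to_buffer_donor(arg)
print(converted)
```

Text without the attribute passes through unchanged, so the helper can be applied to a whole module dump.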

steeve commented 2 weeks ago

TIL about jax.buffer_donor. Unfortunately, it doesn't work either. That said, while the output is wrong, the tok/s doesn't change with donation, which is weird.

devesr-amzn commented 2 weeks ago

Can you provide steps to reproduce the issue along with versions of dependencies in use (versions of neuronx-cc, libneuronxla)?

steeve commented 2 weeks ago

Packages:

neuronx-cc==2.15.141.0+d3cfc8ca
libneuronxla==2.0.4986.0
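For anyone else reporting versions on this issue, a small sketch using only the Python standard library can collect them. It assumes both packages are installed as pip distributions under the names listed above; the helper name installed_versions is made up for illustration.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


report = installed_versions(["neuronx-cc", "libneuronxla"])
for name, version in report.items():
    print(f"{name}=={version}" if version else f"{name}: not installed")
```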

Check out this branch: https://github.com/zml/zml/tree/steeve/synapse

Run the llama example with neuron:

$ cd zml/examples
$ ./bazel.sh run -c opt //llama:Llama-3.1-8B-Instruct --@zml//runtimes:cpu=false --@zml//runtimes:neuron=true

You can re-enable donations by commenting out these lines: https://github.com/zml/zml/blob/steeve/synapse/zml/module.zig#L301-L303