aws-neuron / aws-neuron-sdk

Powering AWS purpose-built machine learning chips. Blazing fast and cost effective, natively integrated into PyTorch and TensorFlow and integrated with your favorite AWS services
https://aws.amazon.com/machine-learning/neuron/

neuronx-cc compile crashes on `tril(x)` #659

Closed: xanderdunn closed this issue 1 year ago

xanderdunn commented 1 year ago

I am attempting to implement the lower-triangle (tril) function in XLA HLO. It is used, for example, to create the self-attention mask (see here).
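
To make the intended use concrete, here is the usual causal-mask pattern written in JAX (purely illustrative; the names and values below are made up and this is not part of my pipeline):

import jax.numpy as jnp

seq = 4                                      # hypothetical sequence length
scores = jnp.ones((seq, seq))                # stand-in for raw attention logits
mask = jnp.tril(jnp.ones((seq, seq)))        # 1 on/below the diagonal, 0 above
masked = jnp.where(mask == 1, scores, -1e9)  # block attention to future positions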

Attached is the XLA HLO .pb generated by my code: rust_hlo_tril.pb.zip

Here is a debug_ir representation of it:

rust debug_ir: HloModule xla_computation_ordered_wrapper, entry_computation_layout={(f32[3,5,7]{0,1,2})->(f32[3,5,7]{0,1,2})}

ENTRY main.13 {
  iota.2 = s32[5]{0} iota(), iota_dimension=0
  broadcast.3 = s32[5,7]{0,1} broadcast(iota.2), dimensions={0}
  iota.4 = s32[7]{0} iota(), iota_dimension=0
  broadcast.5 = s32[5,7]{0,1} broadcast(iota.4), dimensions={1}
  compare.6 = pred[5,7]{0,1} compare(broadcast.3, broadcast.5), direction=GE
  broadcast.7 = pred[3,5,7]{0,1,2} broadcast(compare.6), dimensions={1,2}
  Arg_0.1 = f32[3,5,7]{0,1,2} parameter(0)
  constant.8 = f32[] constant(0)
  broadcast.9 = f32[3,5,7]{0,1,2} broadcast(constant.8), dimensions={}
  select.10 = f32[3,5,7]{0,1,2} select(broadcast.7, Arg_0.1, broadcast.9)
  ROOT tuple.11 = (f32[3,5,7]{0,1,2}) tuple(select.10)
}
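
For readability, the same iota/broadcast/compare/select decomposition can be written out in JAX roughly as follows (a sketch for illustration only; for diagonal=0 this should match what jnp.tril produces):

import jax.numpy as jnp

def tril_reference(x):
    # x: f32[3,5,7]; mirror the HLO graph above
    j, k = x.shape[-2], x.shape[-1]
    rows = jnp.broadcast_to(jnp.arange(j)[:, None], (j, k))  # iota.2 + broadcast.3
    cols = jnp.broadcast_to(jnp.arange(k)[None, :], (j, k))  # iota.4 + broadcast.5
    keep = rows >= cols                                      # compare.6, direction=GE
    keep = jnp.broadcast_to(keep, x.shape)                   # broadcast.7
    return jnp.where(keep, x, jnp.zeros_like(x))             # select.10 against broadcast.9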

And here is what it looks like in my code:

/// Return the lower triangle of the input tensor x.
/// All other values are zeroed out.
/// Generally this is invoked with diagonal=0
/// This assumes x is 3D
pub fn tril(x: &Tensor, diagonal: i64) -> Tensor {
    assert!(x.rank() == 3, "This is currently hard-coded to work only for matrices of rank 3. Generalizing it is a TODO.");
    let k: u64 = *x.shape().last().unwrap();
    let j = x.shape()[x.shape().len() - 2];
    let left_iota = iota(&[j], diagonal, &x.graph_metadata());
    let left_iota = broadcast(&left_iota, &[j, k]);
    let right_iota = iota(&[k], diagonal, &x.graph_metadata());
    let right_iota = broadcast_general(&right_iota, &[j, k], &[1]);
    let compare = ge(&left_iota, &right_iota);
    let compare = broadcast_general(&compare, &x.dims(), &[1, 2]);
    let zeros = Tensor::zeros(&x.shape(), &x.graph_metadata());
    select(&compare, x, &zeros)
}

When I load this .pb and run it on CPU via jax, it succeeds:

    # Note: snippet from a larger test harness; `client`, `xla_comp`, `inputs`,
    # `args`, and `device` are created earlier in the script. Imports assumed:
    import sys
    import time

    import jax
    import jax.numpy as jnp
    from jax.lib import xla_client

    # Compile the HLO to an XLA executable
    start = time.time()
    compile_options = xla_client.CompileOptions()
    compile_options.num_replicas = 1
    compile_options.num_partitions = 1
    xla_comp_mlir = xla_client._xla.mlir.xla_computation_to_mlir_module(xla_comp) # type: ignore
    executable = client.compile(xla_comp_mlir, compile_options)
    print(f"run_xla_cpu_gpu: Took {time.time() - start}s to compile the executble for {args.device}", file=sys.stderr)

    input_devices = []
    input_values = inputs["inputs"]
    for i, input in enumerate(input_values):
        input = jnp.array(input)
        input = jnp.reshape(input, inputs["input_shapes"][i])
        input_devices.append(jax.device_put(input, device))
    start = time.time()
    outputs = executable.execute(input_devices)
    print(f"run_xla_cpu_gpu: Took {time.time() - start}s to execute the executable on {args.device}", file=sys.stderr)

Successful output:

Rust CPU output: Ok([Float32([0.6898583, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.70216024, 0.96511054, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4294598, 0.049052358, 0.45411062, 0.0, 0.0, 0.0, 0.0, 0.7763959, 0.96817935, 0.5216516, 0.0021128654, 0.0, 0.0, 0.0, 0.30257988, 0.12236631, 0.9877379, 0.5611696, 0.96858346, 0.0, 0.0, 0.42054284, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8537997, 0.21233284, 0.0, 0.0, 0.0, 0.0, 0.0, 0.048771977, 0.4087988, 0.66929305, 0.0, 0.0, 0.0, 0.0, 0.88218296, 0.9953327, 0.3668443, 0.5316063, 0.0, 0.0, 0.0, 0.7229221, 0.6043123, 0.7960675, 0.6734197, 0.31135893, 0.0, 0.0, 0.785504, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.26309836, 0.28174615, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9644351, 0.70734894, 0.510087, 0.0, 0.0, 0.0, 0.0, 0.81259894, 0.035218716, 0.6208538, 0.24419558, 0.0, 0.0, 0.0, 0.73210454, 0.45967913, 0.0017743111, 0.13246703, 0.50136197, 0.0, 0.0])])
Python output: [Float32([0.6898583, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.70216024, 0.96511054, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4294598, 0.049052358, 0.45411062, 0.0, 0.0, 0.0, 0.0, 0.7763959, 0.96817935, 0.5216516, 0.0021128654, 0.0, 0.0, 0.0, 0.30257988, 0.12236631, 0.9877379, 0.5611696, 0.96858346, 0.0, 0.0, 0.42054284, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8537997, 0.21233284, 0.0, 0.0, 0.0, 0.0, 0.0, 0.048771977, 0.4087988, 0.66929305, 0.0, 0.0, 0.0, 0.0, 0.88218296, 0.9953327, 0.3668443, 0.5316063, 0.0, 0.0, 0.0, 0.7229221, 0.6043123, 0.7960675, 0.6734197, 0.31135893, 0.0, 0.0, 0.785504, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.26309836, 0.28174615, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9644351, 0.70734894, 0.510087, 0.0, 0.0, 0.0, 0.0, 0.81259894, 0.035218716, 0.6208538, 0.24419558, 0.0, 0.0, 0.0, 0.73210454, 0.45967913, 0.0017743111, 0.13246703, 0.50136197, 0.0, 0.0])]

successes:
    tensor::ops::tests::xla_tril

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 41 filtered out; finished in 1.48s

However, when I compile and run that same .pb for Trainium, neuronx-cc gives me an internal compiler error:

thread 'tensor::ops::tests::xla_tril' panicked at 'Script failed:
2023-04-25T20:53:56Z WARNING 1373790 [LayoutBottleneck]: Connected component _compare.6 has no matmult/reduce/batchnorm. Guessing layout. Considering putting on CPU.
2023-04-25T20:53:56Z ERROR 1373790 [Tensorizer]: Transformation error on operator: _select.10
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]: ***************************************************************
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:  An Internal Compiler Error has occurred
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]: ***************************************************************
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]: Error message:  Mask pattern only support 1 step multiplier
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]: Error class:    AssertionError
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]: Error location: Unknown
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]: Command line:   /home/ubuntu/.cache/pypoetry/virtualenvs/kholinar-ivrH9p2M-py3.8/bin/neuronx-cc compile /tmp/rust_hlo_tril.pb --framework XLA --target trn1 --model-type transformer --auto-cast none --output /tmp/tril.neff
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]: Internal details:
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/driver/CommandDriver.py", line 237, in neuronxcc.driver.CommandDriver.CommandDriver.run
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/driver/commands/CompileCommand.py", line 1047, in neuronxcc.driver.commands.CompileCommand.CompileCommand.run
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/driver/commands/CompileCommand.py", line 998, in neuronxcc.driver.commands.CompileCommand.CompileCommand.runPipeline
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/driver/commands/CompileCommand.py", line 1023, in neuronxcc.driver.commands.CompileCommand.CompileCommand.runPipeline
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/driver/commands/CompileCommand.py", line 1027, in neuronxcc.driver.commands.CompileCommand.CompileCommand.runPipeline
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/driver/Job.py", line 300, in neuronxcc.driver.Job.SingleInputJob.run
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/driver/Job.py", line 326, in neuronxcc.driver.Job.SingleInputJob.runOnState
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/driver/Pipeline.py", line 30, in neuronxcc.driver.Pipeline.Pipeline.runSingleInput
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/driver/Job.py", line 300, in neuronxcc.driver.Job.SingleInputJob.run
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/driver/Job.py", line 326, in neuronxcc.driver.Job.SingleInputJob.runOnState
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/driver/jobs/Frontend.py", line 343, in neuronxcc.driver.jobs.Frontend.Frontend.runSingleInput
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/driver/jobs/Frontend.py", line 148, in neuronxcc.driver.jobs.Frontend.Frontend.runXLAFrontend
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/Penguin.py", line 297, in neuronxcc.starfish.penguin.Penguin.runPenguin
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/Frontend.py", line 150, in neuronxcc.starfish.penguin.Frontend.tensorizeXla
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/Frontend.py", line 151, in neuronxcc.starfish.penguin.Frontend.tensorizeXla
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/Frontend.py", line 159, in neuronxcc.starfish.penguin.Frontend.tensorizeXla
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/Frontend.py", line 212, in neuronxcc.starfish.penguin.Frontend.tensorizeXlaFromFile
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/Compile.py", line 215, in neuronxcc.starfish.penguin.Compile.compile_module
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/Compile.py", line 217, in neuronxcc.starfish.penguin.Compile.compile_module
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/Compile.py", line 260, in neuronxcc.starfish.penguin.Compile.compile_module
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/Compile.py", line 266, in neuronxcc.starfish.penguin.Compile.genenerate_code_and_metadata_for_module
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/Compile.py", line 122, in neuronxcc.starfish.penguin.Compile.generate_code_and_metadata
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/Compile.py", line 352, in neuronxcc.starfish.penguin.Compile.codegen
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/Compile.py", line 358, in neuronxcc.starfish.penguin.Compile.codegenBIR
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/targets/tonga/codegen/BirCodeGenLoop.py", line 1705, in neuronxcc.starfish.penguin.targets.tonga.codegen.BirCodeGenLoop.BirCodeGenLoop.runOnFunction
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/DotTransform.py", line 196, in neuronxcc.starfish.penguin.DotTransform.DotTransform.run_with_exception_handling
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/DotTransform.py", line 191, in neuronxcc.starfish.penguin.DotTransform.DotTransform.run_with_exception_handling
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/DotTransform.py", line 208, in neuronxcc.starfish.penguin.DotTransform.DotTransform.timed_run_
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/DotTransform.py", line 210, in neuronxcc.starfish.penguin.DotTransform.DotTransform.timed_run_
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/DotTransform.py", line 211, in neuronxcc.starfish.penguin.DotTransform.DotTransform.timed_run_
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/DotTransform.py", line 240, in neuronxcc.starfish.penguin.DotTransform.DotTransform.run_
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/DotTransform.py", line 242, in neuronxcc.starfish.penguin.DotTransform.DotTransform.run_
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/DotTransform.py", line 342, in neuronxcc.starfish.penguin.DotTransform.DotTransform.transformFunction
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/DotTransform.py", line 343, in neuronxcc.starfish.penguin.DotTransform.DotTransform.transformFunction
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/DotTransform.py", line 333, in neuronxcc.starfish.penguin.DotTransform.DotTransform.runTransforms
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/DotTransform.py", line 322, in neuronxcc.starfish.penguin.DotTransform.DotTransform.transformStmts
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/DotTransform.py", line 134, in neuronxcc.starfish.penguin.DotTransform.DotTransform.transform
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/DotTransform.py", line 375, in neuronxcc.starfish.penguin.DotTransform.DotTransform.transformBasicBlock
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/DotTransform.py", line 378, in neuronxcc.starfish.penguin.DotTransform.DotTransform.transformBasicBlock
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/DotTransform.py", line 134, in neuronxcc.starfish.penguin.DotTransform.DotTransform.transform
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/targets/tonga/codegen/BirCodeGenLoop.py", line 1508, in neuronxcc.starfish.penguin.targets.tonga.codegen.BirCodeGenLoop.BirCodeGenLoop.transformInstruction
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/targets/tonga/codegen/BirCodeGenLoop.py", line 1281, in neuronxcc.starfish.penguin.targets.tonga.codegen.BirCodeGenLoop.BirCodeGenLoop.addInstToBir
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/targets/tonga/codegen/BirCodeGenLoop.py", line 1278, in neuronxcc.starfish.penguin.targets.tonga.codegen.BirCodeGenLoop.BirCodeGenLoop.dispatch_codegen
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   File "neuronxcc/starfish/penguin/targets/tonga/codegen/BirCodeGenLoop.py", line 306, in neuronxcc.starfish.penguin.targets.tonga.codegen.BirCodeGenLoop.BirCodeGenLoop.codegenAffSelTensorScalarOp
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]: Version information:
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   NeuronX Compiler version 2.5.0.28+1be23f232
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   HWM version 2.5.0.0-dad732dd6
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   NEFF version Dynamic
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   TVM not available
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   NumPy version 1.21.6
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:   MXNet not available
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]:
2023-04-25T20:53:56Z ERROR 1373790 [neuronx-cc]: Artifacts stored in: /home/ubuntu/dev/Kholinar/xla/neuronxcc-b1zz1m0b
, stdout:
', xla/src/xla_utils.rs:29:5
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

failures:
    tensor::ops::tests::xla_tril

test result: FAILED. 0 passed; 1 failed; 0 ignored; 0 measured; 41 filtered out; finished in 6.23s

I have many other .pb files running successfully on Trainium (softmax, gelu, rmsnorm, etc.); this is the first one I've encountered that causes an issue with the Neuron compiler. The compiler crashes even when tril does not depend on an f32 input tensor but is instead applied to something like tril(tensor.ones(x.shape)); an example XLA graph of that is attached here: rust_hlo_tril.pb.zip

Please let me know if I am trying something unsupported here and how I ought to work around it. Thanks!

micwade-aws commented 1 year ago

Thanks for reaching out - we're taking a look!

micwade-aws commented 1 year ago

@xanderdunn - this is a bug in neuronx-cc related to the iota operator. We’re working on a fix and will keep this ticket updated as we make progress.

aws-donkrets commented 1 year ago

@xanderdunn - we believe your reported issue is fixed in the latest Neuron SDK (2.11). Please give it a try and update/close this ticket as appropriate.

xanderdunn commented 1 year ago

Confirmed: after upgrading to Neuron SDK 2.11, this compiled for me!

$ neuronx-cc compile /tmp/rust_hlo_tril.pb --framework XLA --target trn1 --model-type transformer --auto-cast none --output /tmp/tril.neff
2023-06-16T23:23:42Z WARNING 111032 [LayoutBottleneck]: Connected component _compare.6 has no matmult/reduce/batchnorm. Guessing layout. Considering putting on CPU.
Selecting 4 allocations
0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
***************************************************
Analyzing dependencies of Block1
0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
***************************************************
Analyzing dependencies of Block1
0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
***************************************************
Dependency reduction of sg0000
0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
***************************************************

Huge thanks for the bug fix!