aws-neuron / aws-neuron-sdk

Powering AWS purpose-built machine learning chips. Blazing fast and cost effective, natively integrated into PyTorch and TensorFlow and integrated with your favorite AWS services
https://aws.amazon.com/machine-learning/neuron/

Error Benchmarking Automatically Generated NKI #1015

Closed nandeeka closed 1 month ago

nandeeka commented 1 month ago

I am trying to benchmark NKI code generated from the corresponding PyTorch code. My workflow is as follows:

  1. I run the PyTorch kernel and extract the .pb file
  2. I compile the NKI kernel using:
    neuronx-cc compile --framework XLA <.pb> --target trn1 --verbose info --pipeline compile SaveTemps --tensorizer-options='--print-nki --nki-dl'
  3. I append the following lines, where <inputs/outputs> are the same as the parameters to nki.simulate_kernel in the automatically generated kernel:
    func = nki.benchmark()(sg0000)
    func(
    <inputs/outputs>
    )
  4. I run the new file.

My first question is: is there a better way to get this kind of fine-grained benchmarking information?

Sometimes (e.g., for matrix multiplication), this flow works great. But other times, I see the error:

Traceback (most recent call last):
  File "/home/ubuntu/nki-kernels/out/nki.py", line 153, in <module>
    func(
  File "neuronxcc/starfish/penguin/targets/nki/TraceKernel.py", line 808, in neuronxcc.starfish.penguin.targets.nki.TraceKernel.Kernel.__call__
  File "neuronxcc/starfish/penguin/targets/nki/TraceKernel.py", line 1088, in neuronxcc.starfish.penguin.targets.nki.TraceKernel.BaremetalKernel.post_process_call
  File "neuronxcc/starfish/penguin/targets/nki/TraceKernel.py", line 1091, in neuronxcc.starfish.penguin.targets.nki.TraceKernel.BaremetalKernel.post_process_call
  File "neuronxcc/starfish/penguin/targets/nki/TraceKernel.py", line 1158, in neuronxcc.starfish.penguin.targets.nki.TraceKernel.BaremetalKernel._compile
  File "neuronxcc/starfish/penguin/targets/nki/TraceKernel.py", line 57, in neuronxcc.starfish.penguin.targets.nki.TraceKernel.write_tensorizer_ir
  File "neuronxcc/starfish/penguin/targets/nki/TraceKernel.py", line 58, in neuronxcc.starfish.penguin.targets.nki.TraceKernel.write_tensorizer_ir
  File "neuronxcc/starfish/penguin/ir/IRWriter.py", line 58, in neuronxcc.starfish.penguin.ir.IRWriter.IRWriter.run
  File "neuronxcc/starfish/penguin/ir/SerializerBase.py", line 239, in neuronxcc.starfish.penguin.ir.SerializerBase.SerializerBase.serialize_dispatch
  File "neuronxcc/starfish/penguin/ir/SerializerBase.py", line 269, in neuronxcc.starfish.penguin.ir.SerializerBase.SerializerBase.serialize_func
  File "neuronxcc/starfish/penguin/ir/SerializerBase.py", line 241, in neuronxcc.starfish.penguin.ir.SerializerBase.SerializerBase.serialize_dispatch
  File "neuronxcc/starfish/penguin/ir/SerializerBase.py", line 278, in neuronxcc.starfish.penguin.ir.SerializerBase.SerializerBase.serialize_block
  File "neuronxcc/starfish/penguin/ir/SerializerBase.py", line 250, in neuronxcc.starfish.penguin.ir.SerializerBase.SerializerBase.serialize_dispatch
  File "neuronxcc/starfish/penguin/targets/tonga/TongaISAInst.py", line 3903, in neuronxcc.starfish.penguin.targets.tonga.TongaISAInst.TransposeOp.serialize
  File "neuronxcc/starfish/penguin/ir/IRWriter.py", line 231, in neuronxcc.starfish.penguin.ir.IRWriter.IRWriter.engine
  File "neuronxcc/starfish/penguin/ir/IRWriter.py", line 123, in neuronxcc.starfish.penguin.ir.IRWriter.IRWriter.module
AttributeError: 'int' object has no attribute '__module__'. Did you mean: '__mod__'?

The original PyTorch code is:

import os
import sys

import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as ni
import torch
from torch_neuronx import nki_jit
from torch_xla.core import xla_model as xm

def unmerged_lora(X, W0, A, SB, R):
  delta_W = torch.matmul(SB, torch.matmul(A, X))
  return torch.matmul(W0, X) + delta_W

def main():
  K = 4096
  M = 4096
  N = int(sys.argv[1])

  # Parameter from Idefics2
  # Source: https://colab.research.google.com/drive/1NtcTgRbSBKN7pYD3Vdx1j9m8pt3fhFDB?usp=sharing#scrollTo=SMujNa2vKbZd
  R = 8

  device = xm.xla_device()
  cpu = torch.device('cpu')

  W0 = torch.randn(M, K, dtype=torch.float16).to(device)
  X = torch.randn(K, N, dtype=torch.float16).to(device)
  A = torch.randn(R, K, dtype=torch.float16).to(device)
  SB = torch.randn(M, R, dtype=torch.float16).to(device)

  H = unmerged_lora(X, W0, A, SB, R)

  H = H.to(device=cpu)

  print(H)

if __name__ == "__main__":
  os.environ["NEURON_FRAMEWORK_DEBUG"] = "1"
  os.environ["NEURON_CC_FLAGS"]= " --disable-internal-io-dge --tensorizer-options='--print-nki --nki-dl'"

  main()
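
As a quick sanity check on the unmerged-LoRA math, here is a minimal NumPy sketch (independent of Neuron, and with smaller sizes than the script above, purely for illustration):

# Minimal NumPy shape check of unmerged_lora: W0 @ X + SB @ (A @ X)
# (smaller K/M than the script above; N=1 as an example command-line argument)
import numpy as np

K, M, N, R = 512, 512, 1, 8
W0 = np.random.randn(M, K).astype(np.float16)   # (M, K)
X  = np.random.randn(K, N).astype(np.float16)   # (K, N)
A  = np.random.randn(R, K).astype(np.float16)   # (R, K)
SB = np.random.randn(M, R).astype(np.float16)   # (M, R)

H = W0 @ X + SB @ (A @ X)                       # (M, N) + (M, R)@(R, N) -> (M, N)
print(H.shape)                                  # (512, 1)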

The PyTorch code above produces the following NKI code:

import numpy as np
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.typing as nt
import neuronxcc.nki.isa as nisa
from neuronxcc.nki import trace
from neuronxcc.nki.language import par_dim

@trace
def sg0000(
  v1,
  v2,
  v3,
  v4,
  v5,
):
  import numpy as np
  import neuronxcc.nki as nki
  import neuronxcc.nki.language as nl
  import neuronxcc.nki.typing as nt
  import neuronxcc.nki.isa as nisa
  from neuronxcc.nki import trace
  from neuronxcc.nki.language import par_dim

  v1 = v1
  v2 = v2
  v3 = v3
  v4 = v4
  v5 = v5

  v6 = nl.shared_constant(np.identity(128, dtype=np.float16))
  v7 = nl.ndarray((nl.par_dim(128), 128), dtype=np.float16, name="identity_local_120", buffer=nl.sbuf)
  v8 = nl.ndarray((2, nl.par_dim(8), 2048), dtype=np.float16, name="40.111", buffer=nl.sbuf)
  v9 = nl.zeros((2, 16, nl.par_dim(128), 8), dtype=np.float32, name="40.115", buffer=nl.psum, lazy_initialization=True)
  v10 = nl.ndarray((2, nl.par_dim(128), 16, 8), dtype=np.float16, name="input1_pftranspose_40", buffer=nl.sbuf)
  v11 = nl.ndarray((nl.par_dim(32), 128), dtype=np.float16, name="dot.1.166", buffer=nl.sbuf)
  v12 = nl.zeros((nl.par_dim(128), 32), dtype=np.float32, name="dot.1.156", buffer=nl.psum, lazy_initialization=True)
  v13 = nl.ndarray((nl.par_dim(128), 2, 16), dtype=np.float16, name="", buffer=nl.sbuf)
  v14 = nl.zeros((nl.par_dim(1), 8), dtype=np.float32, name="", buffer=nl.psum, lazy_initialization=True)
  v15 = nl.ndarray((nl.par_dim(8), 1), dtype=np.float16, name="dot.2", buffer=nl.sbuf)
  v16 = nl.ndarray((nl.par_dim(32), 128), dtype=np.float16, name="dot.169", buffer=nl.sbuf)
  v17 = nl.zeros((nl.par_dim(128), 32), dtype=np.float32, name="dot.160", buffer=nl.psum, lazy_initialization=True)
  v18 = nl.ndarray((nl.par_dim(128), 2, 16), dtype=np.float16, name="", buffer=nl.sbuf)
  v19 = nl.ndarray((2, 16, nl.par_dim(128), 8), dtype=np.float16, name="44.171", buffer=nl.sbuf)
  v20 = nl.zeros((2, 16, nl.par_dim(8), 128), dtype=np.float32, name="44.140", buffer=nl.psum, lazy_initialization=True)
  v21 = nl.ndarray((2, nl.par_dim(8), 2048), dtype=np.float16, name="input2_pftranspose_44", buffer=nl.sbuf)
  v22 = nl.zeros((2, 4, nl.par_dim(1), 512), dtype=np.float32, name="dot.2.164", buffer=nl.psum, lazy_initialization=True)
  v23 = nl.ndarray((2, 4, nl.par_dim(1), 512), dtype=np.float16, name="", buffer=nl.sbuf)
  v24 = nl.ndarray((2, 4, 2, 4, nl.par_dim(128), 2048), dtype=np.float16, name="48.145", buffer=nl.sbuf)
  v25 = nl.zeros((2, 4, 2, 4, 16, nl.par_dim(128), 128), dtype=np.float32, name="48.150", buffer=nl.psum, lazy_initialization=True)
  v26 = nl.ndarray((2, 2, 4, nl.par_dim(128), 16, 512), dtype=np.float16, name="input3_pftranspose_48", buffer=nl.sbuf)
  v27 = nl.zeros((2, 4, nl.par_dim(1), 512), dtype=np.float32, name="", buffer=nl.psum, lazy_initialization=True)
  v28 = nl.ndarray((2, nl.par_dim(1), 2048), dtype=np.float16, name="", buffer=nl.sbuf)
  v29 = nl.ndarray((nl.par_dim(1), 8), dtype=np.float16, name="36.173", buffer=nl.sbuf)

  def BB_entry_1():
    """ tensor_op_name: input1_pftranspose_40 | hlo_id: 21 |  """
    v7[nl.arange(128)[:, None], nl.arange(128)[None, :]] = nl.load(v6[nl.arange(128)[:, None], nl.arange(128)[None, :]], dtype=np.float16, mask=None)

    for i0 in nl.affine_range(2):
      """ tensor_op_name: input1_pftranspose_40 | hlo_id: 21 |  """
      v8[i0, nl.arange(8)[:, None, None], nl.arange(128)[None, None, :]+128*nl.arange(16)[None, :, None]] = nl.load(v2[nl.arange(8)[:, None, None], 16*i0+nl.arange(16)[None, :, None], nl.arange(128)[None, None, :]], dtype=np.float16, mask=None)

      for i1 in nl.affine_range(16):
        """ tensor_op_name: input1_pftranspose_40 | hlo_id: 21 |  """
        v9[i0, i1, nl.arange(128)[:, None], nl.arange(8)[None, :]] = nisa.nc_matmul(v8[i0, nl.arange(8)[:, None], nl.arange(128)[None, :]+128*i1], v7[nl.arange(8)[:, None], nl.arange(8)[None, :]], is_stationary_onezero=False, is_moving_onezero=True, mask=None, is_transpose=True)
        """ tensor_op_name: input1_pftranspose_40 | hlo_id: 21 |  """
        v10[i0, nl.arange(128)[:, None], i1, nl.arange(8)[None, :]] = nl.copy(v9[i0, i1, nl.arange(128)[:, None], nl.arange(8)[None, :]], dtype=np.float16, mask=None)
        """ end loop i1 """
      """ end loop i0 """
    """ tensor_op_name: _dot.1 | hlo_id: 21 |  """
    v11[nl.arange(32)[:, None], nl.arange(128)[None, :]] = nl.load(v1[nl.arange(32)[:, None], nl.arange(128)[None, :]], dtype=np.float16, mask=None)
    """ tensor_op_name: _dot.1 | hlo_id: 21 |  """
    v12[nl.arange(128)[:, None, None], 16*nl.arange(2)[None, :, None]+nl.arange(16)[None, None, :]] = nisa.nc_matmul(v11[nl.arange(32)[:, None], nl.arange(128)[None, :]], v7[nl.arange(32)[:, None, None], 16*nl.arange(2)[None, :, None]+nl.arange(16)[None, None, :]], is_stationary_onezero=False, is_moving_onezero=True, mask=None, is_transpose=True)
    """ tensor_op_name: _dot.1 | hlo_id: 21 |  """
    v13[nl.arange(128)[:, None, None], nl.arange(2)[None, :, None], nl.arange(16)[None, None, :]] = nl.copy(v12[nl.arange(128)[:, None, None], 16*nl.arange(2)[None, :, None]+nl.arange(16)[None, None, :]], dtype=np.float16, mask=None)

    for i2 in nl.affine_range(2):
      for i3 in nl.affine_range(16):
        """ tensor_op_name: _dot.1 | hlo_id: 21 |  """
        v14[0, nl.arange(8)[None, :]] += nisa.nc_matmul(v13[nl.arange(128)[:, None], i2, i3], v10[i2, nl.arange(128)[:, None], i3, nl.arange(8)[None, :]], is_stationary_onezero=False, is_moving_onezero=False, mask=None)
        """ end loop i3 """
      """ end loop i2 """
    """ tensor_op_name: dot.2_pftranspose_36 | hlo_id: 21 |  """
    v29[0, nl.arange(8)[None, :]] = nl.copy(v14[0, nl.arange(8)[None, :]], dtype=np.float16, mask=None)
    """ tensor_op_name: dot.2_pftranspose_36 | hlo_id: 21 |  """
    v15[nl.arange(8)[:, None], 0] = nisa.nc_transpose(v29[0, nl.arange(8)[None, :]], dtype=np.float16, mask=None, engine=0)
    """ tensor_op_name: _dot | hlo_id: 18 |  """
    v16[nl.arange(32)[:, None], nl.arange(128)[None, :]] = nl.load(v1[nl.arange(32)[:, None], nl.arange(128)[None, :]], dtype=np.float16, mask=None)
    """ tensor_op_name: _dot | hlo_id: 18 |  """
    v17[nl.arange(128)[:, None, None], 16*nl.arange(2)[None, :, None]+nl.arange(16)[None, None, :]] = nisa.nc_matmul(v16[nl.arange(32)[:, None], nl.arange(128)[None, :]], v7[nl.arange(32)[:, None, None], 16*nl.arange(2)[None, :, None]+nl.arange(16)[None, None, :]], is_stationary_onezero=False, is_moving_onezero=True, mask=None, is_transpose=True)
    """ tensor_op_name: _dot | hlo_id: 18 |  """
    v18[nl.arange(128)[:, None, None], nl.arange(2)[None, :, None], nl.arange(16)[None, None, :]] = nl.copy(v17[nl.arange(128)[:, None, None], 16*nl.arange(2)[None, :, None]+nl.arange(16)[None, None, :]], dtype=np.float16, mask=None)

    for i4 in nl.affine_range(2):
      for i5 in nl.affine_range(16):
        """ tensor_op_name: input2_pftranspose_44 | hlo_id: 24 |  """
        v19[i4, i5, nl.arange(128)[:, None], nl.arange(8)[None, :]] = nl.load(v3[2048*i4+128*i5+nl.arange(128)[:, None], nl.arange(8)[None, :]], dtype=np.float16, mask=None)
        """ tensor_op_name: input2_pftranspose_44 | hlo_id: 24 |  """
        v20[i4, i5, nl.arange(8)[:, None], nl.arange(128)[None, :]] = nisa.nc_matmul(v19[i4, i5, nl.arange(128)[:, None], nl.arange(8)[None, :]], v7[nl.arange(128)[:, None], nl.arange(128)[None, :]], is_stationary_onezero=False, is_moving_onezero=True, mask=None, is_transpose=True)
        """ tensor_op_name: input2_pftranspose_44 | hlo_id: 24 |  """
        v21[i4, nl.arange(8)[:, None], nl.arange(128)[None, :]+128*i5] = nl.copy(v20[i4, i5, nl.arange(8)[:, None], nl.arange(128)[None, :]], dtype=np.float16, mask=None)
        """ end loop i5 """

      for i6 in nl.affine_range(4):
        """ tensor_op_name: _dot.2 | hlo_id: 24 |  """
        v22[i4, i6, 0, nl.arange(512)[None, :]] = nisa.nc_matmul(v15[nl.arange(8)[:, None], 0], v21[i4, nl.arange(8)[:, None], 512*i6+nl.arange(512)[None, :]], is_stationary_onezero=False, is_moving_onezero=False, mask=None)
        """ tensor_op_name: _dot.2 | hlo_id: 24 |  """
        v23[i4, i6, 0, nl.arange(512)[None, :]] = nl.copy(v22[i4, i6, 0, nl.arange(512)[None, :]], dtype=np.float16, mask=None)

        for i7 in nl.affine_range(2):
          for i8 in nl.affine_range(4):
            """ tensor_op_name: input3_pftranspose_48 | hlo_id: 18 |  """
            v24[i4, i6, i7, i8, nl.arange(128)[:, None, None], 128*nl.arange(16)[None, :, None]+nl.arange(128)[None, None, :]] = nl.load(v4[16*i4+4*i6+i8, nl.arange(128)[:, None, None], 16*i7+nl.arange(16)[None, :, None], nl.arange(128)[None, None, :]], dtype=np.float16, mask=None)

            for i9 in nl.affine_range(16):
              """ tensor_op_name: input3_pftranspose_48 | hlo_id: 18 |  """
              v25[i4, i6, i7, i8, i9, nl.arange(128)[:, None], nl.arange(128)[None, :]] = nisa.nc_matmul(v24[i4, i6, i7, i8, nl.arange(128)[:, None], nl.arange(128)[None, :]+128*i9], v7[nl.arange(128)[:, None], nl.arange(128)[None, :]], is_stationary_onezero=False, is_moving_onezero=True, mask=None, is_transpose=True)
              """ tensor_op_name: input3_pftranspose_48 | hlo_id: 18 |  """
              v26[i7, i4, i6, nl.arange(128)[:, None], i9, 128*i8+nl.arange(128)[None, :]] = nl.copy(v25[i4, i6, i7, i8, i9, nl.arange(128)[:, None], nl.arange(128)[None, :]], dtype=np.float16, mask=None)
              """ end loop i9 """
            """ end loop i8 """
          """ end loop i7 """

        for i10 in nl.affine_range(2):
          for i11 in nl.affine_range(16):
            """ tensor_op_name: _dot | hlo_id: 18 |  """
            v27[i4, i6, 0, nl.arange(512)[None, :]] += nisa.nc_matmul(v18[nl.arange(128)[:, None], i10, i11], v26[i10, i4, i6, nl.arange(128)[:, None], i11, nl.arange(512)[None, :]], is_stationary_onezero=False, is_moving_onezero=False, mask=None)
            """ end loop i11 """
          """ end loop i10 """
        """ tensor_op_name: _add.0 | hlo_id: 36 |  """
        v28[i4, 0, 512*i6+nl.arange(512)[None, :]] = nl.add(v27[i4, i6, 0, nl.arange(512)[None, :]], v23[i4, i6, 0, nl.arange(512)[None, :]], mask=None, dtype=np.float16)
        """ end loop i6 """
      """ tensor_op_name: _add.0 | hlo_id: 36 |  """
      nl.store(v5[2048*i4+nl.arange(2048)[None, :]], value=v28[i4, 0, nl.arange(2048)[None, :]], mask=None)
      """ end loop i4 """

  BB_entry_1()

cu = sg0000.specialize(
  nt.tensor[(32, 128), np.float16], # i=0
  nt.tensor[(8, 32, 128), np.float16], # i=1
  nt.tensor[(4096, 8), np.float16], # i=2
  nt.tensor[(32, 128, 32, 128), np.float16], # i=3
  nt.tensor[(4096,), np.float16], # i=4
)
print(cu)
ir = cu

# nki.simulate_kernel(sg0000, 
  # np.ndarray(shape=(32, 128), dtype=np.float16), # i=0
  # np.ndarray(shape=(8, 32, 128), dtype=np.float16), # i=1
  # np.ndarray(shape=(4096, 8), dtype=np.float16), # i=2
  # np.ndarray(shape=(32, 128, 32, 128), dtype=np.float16), # i=3
  # np.ndarray(shape=(4096,), dtype=np.float16), # i=4
# )

To benchmark this kernel, I add the following to the bottom:

func = nki.benchmark()(sg0000)
func(
  np.ndarray(shape=(32, 128), dtype=np.float16), # i=0
  np.ndarray(shape=(8, 32, 128), dtype=np.float16), # i=1
  np.ndarray(shape=(4096, 8), dtype=np.float16), # i=2
  np.ndarray(shape=(32, 128, 32, 128), dtype=np.float16), # i=3
  np.ndarray(shape=(4096,), dtype=np.float16), # i=4
)
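
Note that np.ndarray(shape=...) allocates uninitialized memory, so for an actual benchmark run I would expect initialized arrays of the same shapes (e.g. np.zeros) to be safer. Below is a minimal sketch of the same call with explicit benchmark options; the warmup/iters/save_neff_name keyword arguments and the benchmark_result.nc_latency accessor are assumptions based on the public nki.benchmark documentation:

# Sketch (not verified on this kernel): same call as above, but with
# initialized inputs and explicit options. The kwargs and the result
# accessor below are assumed from the public nki.benchmark docs.
import numpy as np
import neuronxcc.nki as nki

func = nki.benchmark(warmup=5, iters=10, save_neff_name="sg0000.neff")(sg0000)
func(
  np.zeros((32, 128), dtype=np.float16),          # i=0
  np.zeros((8, 32, 128), dtype=np.float16),       # i=1
  np.zeros((4096, 8), dtype=np.float16),          # i=2
  np.zeros((32, 128, 32, 128), dtype=np.float16), # i=3
  np.zeros((4096,), dtype=np.float16),            # i=4
)
# Per the docs, latencies are reported in microseconds.
p99 = func.benchmark_result.nc_latency.get_latency_percentile(99)
print("p99 latency (us):", p99)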

Environment: I started with the Neuron 2.20 DLAMI and installed the Allocation API using the .deb and .whl files @aws-serina-tan sent me.

aws-serina-tan commented 1 month ago

This particular issue is because NKI-codegen (currently an experimental feature in NKI) doesn't print the engine parameter correctly. We are opening a ticket to track this internally and will post here once we have a fix. A temporary workaround is to remove engine=0 from the generated NKI file.
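
For example, the only line in the generated kernel above that passes engine=0 is the nc_transpose call; the workaround amounts to:

# generated: the explicit engine argument is what trips the serializer
v15[nl.arange(8)[:, None], 0] = nisa.nc_transpose(v29[0, nl.arange(8)[None, :]], dtype=np.float16, mask=None, engine=0)

# workaround: drop engine=0 and keep everything else the same
v15[nl.arange(8)[:, None], 0] = nisa.nc_transpose(v29[0, nl.arange(8)[None, :]], dtype=np.float16, mask=None)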

If the goal here is to benchmark a PyTorch model, do the tools in Performance and Benchmark Tools (https://awsdocs-neuron.readthedocs-hosted.com/en/latest/tools/index.html) help?

nandeeka commented 1 month ago

This worked. Thank you so much!