aws-neuron / aws-neuron-sdk

Powering AWS purpose-built machine learning chips. Blazing fast and cost effective, natively integrated into PyTorch and TensorFlow and integrated with your favorite AWS services
https://aws.amazon.com/machine-learning/neuron/

SBUF non-uniform utilization #999

Open zhdllwyc opened 6 days ago

zhdllwyc commented 6 days ago

We observe non-uniform SBUF utilization in neuron-profile on both trn1.32xlarge and trn1.2xlarge. Following the NKI tutorial, I launch a matrix multiplication that breaks the matrices into tiles. In the source code, the left-hand-side and right-hand-side matrices are loaded into SBUF as tiles: lhsT_tile and rhs_tile. In the NKI architecture guide, Figure 28 shows that a tensor should span tensor.shape[0] SBUF partitions. Since the kernel explicitly specifies the K dimension (K=128) as the partition dimension (see the source code below), I expect lhsT_tile and rhs_tile to occupy the SBUF partitions uniformly. However, that is not what I observe after profiling. Here is a snapshot of the SBUF partition utilization:

Partition 0 usage: 0.8255411982536316
Partition 32 usage: 0.0599263533949852
Partition 64 usage: 0.0599263533949852
Partition 96 usage: 0.0599263533949852
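To make the expectation concrete, here is a minimal sketch of the footprint model implied by the description above: a tile of shape [P, F] spans P partitions and places F elements in each one. The 128-partition count and the 2-byte bfloat16 element size are assumptions taken from the kernel's tile shapes and input dtype, and `per_partition_bytes` is a hypothetical helper, not part of NKI or neuron-profile. The last lines compare that uniform model with the reported numbers.

```python
# Assumptions (not read from the hardware): 128 SBUF partitions, bfloat16 inputs.
NUM_PARTITIONS = 128
BYTES_PER_ELEMENT = 2  # bfloat16


def per_partition_bytes(shape, bytes_per_element=BYTES_PER_ELEMENT):
    """Bytes a tile places in each partition it spans.

    shape[0] is the partition dimension; the remaining dims are the free dims.
    """
    partitions, *free_dims = shape
    free_elems = 1
    for dim in free_dims:
        free_elems *= dim
    return {p: free_elems * bytes_per_element for p in range(partitions)}


lhsT_tile = per_partition_bytes((128, 128))  # (TILE_K, TILE_M)
rhs_tile = per_partition_bytes((128, 512))   # (TILE_K, TILE_N)
expected = {p: lhsT_tile[p] + rhs_tile[p] for p in range(NUM_PARTITIONS)}

# Under this model every partition carries the same number of bytes.
uniform = len(set(expected.values())) == 1

# Values reported by the profiler in this issue.
observed = {
    0: 0.8255411982536316,
    32: 0.0599263533949852,
    64: 0.0599263533949852,
    96: 0.0599263533949852,
}
skew = observed[0] / observed[32]

print(f"expected bytes per partition: {expected[0]} (uniform: {uniform})")
print(f"observed partition 0 vs partition 32: {skew:.1f}x")
```

The model predicts identical usage for partitions 0, 32, 64 and 96, whereas the reported figures differ by more than an order of magnitude.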

Below is my source code:

import neuronxcc.nki.language as nl
import torch
from torch_neuronx import nki_jit
from torch_xla.core import xla_model as xm
import os

def nki_matmul_tiled_(lhsT, rhs, result):
  """NKI kernel to compute a matrix multiplication operation in a tiled manner

  Args:
      lhsT: an input tensor of shape [K,M], where both K and M are multiples of
        128.  It is the left-hand-side argument of the matrix multiplication,
        delivered transposed for optimal performance.
      rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
        is a multiple of 512.  It is the right-hand-side argument of the matrix
        multiplication.
      result: the resulting output tensor of shape [M,N]
  """

  K, M = lhsT.shape
  K_, N = rhs.shape
  assert K == K_, "lhsT and rhs must have the same contraction dimension"

  TILE_M = 128
  TILE_K = 128
  TILE_N = 512

  # Use affine_range to loop over tiles
  for m in nl.affine_range(M // TILE_M):
    for n in nl.affine_range(N // TILE_N):
      # Allocate a tensor in PSUM
      res_psum = nl.zeros((TILE_M, TILE_N), nl.float32, buffer=nl.psum)

      for k in nl.affine_range(K // TILE_K):
        # Declare the tiles on SBUF
        lhsT_tile = nl.ndarray((nl.par_dim(TILE_K), TILE_M), dtype=lhsT.dtype, buffer=nl.sbuf)
        rhs_tile = nl.ndarray((nl.par_dim(TILE_K), TILE_N), dtype=rhs.dtype, buffer=nl.sbuf)

        # Load tiles from lhsT and rhs
        lhsT_tile[...] = nl.load(lhsT[k * TILE_K:(k + 1) * TILE_K,
                                      m * TILE_M:(m + 1) * TILE_M])
        rhs_tile[...] = nl.load(rhs[k * TILE_K:(k + 1) * TILE_K,
                                    n * TILE_N:(n + 1) * TILE_N])

        # Accumulate partial-sums into PSUM
        res_psum += nl.matmul(lhsT_tile[...], rhs_tile[...], transpose_x=True)

      # Copy the result from PSUM back to SBUF, and cast to expected output data-type
      res_sb = nl.copy(res_psum, dtype=result.dtype)
      nl.store(result[m * TILE_M:(m + 1) * TILE_M, n * TILE_N:(n + 1) * TILE_N],
               value=res_sb)

if __name__ == "__main__":

  os.environ["NEURON_FRAMEWORK_DEBUG"] = "1"
  os.environ["NEURON_CC_FLAGS"] = " --disable-internal-io-dge  "

  device = xm.xla_device()

  # Test the large workload with the tiled kernel
  lhs_big = torch.rand((4096, 4096), dtype=torch.bfloat16, device=device)
  rhs_big = torch.rand((4096, 8192), dtype=torch.bfloat16, device=device)
  output_big = torch.zeros((4096, 8192), dtype=torch.bfloat16, device=device)

  # Run NKI kernel
  nki_matmul_basic_jit = nki_jit(nki_matmul_tiled_)
  nki_matmul_basic_jit(lhs_big.T, rhs_big, output_big)

  print(output_big)
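For checking the kernel's output independently of the device, the same loop structure can be mirrored on the CPU. This is only a NumPy sketch of the tiling scheme (accumulate over K tiles in float32, then write one [TILE_M, TILE_N] block); `matmul_tiled_reference` is a hypothetical helper name, it says nothing about SBUF placement, and it uses small float32 inputs rather than the 4096 x 8192 bfloat16 workload.

```python
import numpy as np

TILE_M, TILE_K, TILE_N = 128, 128, 512


def matmul_tiled_reference(lhsT, rhs):
    """CPU reference that follows the kernel's m / n / k tile loops."""
    K, M = lhsT.shape
    K_, N = rhs.shape
    assert K == K_, "lhsT and rhs must have the same contraction dimension"
    assert M % TILE_M == 0 and K % TILE_K == 0 and N % TILE_N == 0

    result = np.zeros((M, N), dtype=np.float32)
    for m in range(M // TILE_M):
        for n in range(N // TILE_N):
            # Plays the role of res_psum: a float32 accumulator per output tile.
            acc = np.zeros((TILE_M, TILE_N), dtype=np.float32)
            for k in range(K // TILE_K):
                lhsT_tile = lhsT[k * TILE_K:(k + 1) * TILE_K,
                                 m * TILE_M:(m + 1) * TILE_M]
                rhs_tile = rhs[k * TILE_K:(k + 1) * TILE_K,
                               n * TILE_N:(n + 1) * TILE_N]
                # lhsT is stored transposed, so transpose the tile back.
                acc += lhsT_tile.T.astype(np.float32) @ rhs_tile.astype(np.float32)
            result[m * TILE_M:(m + 1) * TILE_M,
                   n * TILE_N:(n + 1) * TILE_N] = acc
    return result


rng = np.random.default_rng(0)
lhsT_small = rng.random((256, 128), dtype=np.float32)  # [K, M]
rhs_small = rng.random((256, 512), dtype=np.float32)   # [K, N]
out_small = matmul_tiled_reference(lhsT_small, rhs_small)
matches = np.allclose(out_small, lhsT_small.T @ rhs_small, rtol=1e-4, atol=1e-4)
print("tiled reference matches lhsT.T @ rhs:", matches)
```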

My pip freeze is:

absl-py==2.1.0
accelerate==0.34.2
anyio==4.6.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
attrs==24.2.0
aws-neuronx-runtime-discovery==2.9
awscli==1.34.25
babel==2.16.0
beautifulsoup4==4.12.3
bleach==6.1.0
boto3==1.35.25
botocore==1.35.25
cachetools==5.5.0
certifi==2024.8.30
cffi==1.17.1
charset-normalizer==3.3.2
cloud-tpu-client==0.10
colorama==0.4.6
comm==0.2.2
debugpy==1.8.5
decorator==5.1.1
defusedxml==0.7.1
docutils==0.16
ec2-metadata==2.13.0
environment-kernels==1.2.0
exceptiongroup==1.2.2
executing==2.1.0
fastjsonschema==2.20.0
filelock==3.16.1
fqdn==1.5.1
fsspec==2024.9.0
google-api-core==1.34.1
google-api-python-client==1.8.0
google-auth==2.35.0
google-auth-httplib2==0.2.0
googleapis-common-protos==1.65.0
h11==0.14.0
httpcore==1.0.5
httplib2==0.22.0
httpx==0.27.2
huggingface-hub==0.25.1
idna==3.10
ipykernel==6.29.5
ipython==8.27.0
ipywidgets==8.1.5
islpy==2023.2.5
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.4
jmespath==1.0.1
json5==0.9.25
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter==1.1.1
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.14.2
jupyter_server_terminals==0.5.3
jupyterlab==4.2.5
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.13
libneuronxla==2.0.4115.0
lockfile==0.12.2
MarkupSafe==2.1.5
matplotlib-inline==0.1.7
mistune==3.0.2
ml-dtypes==0.2.0
mpmath==1.3.0
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==2.8.8
neuronx-cc==2.15.128.0+56dc5a86
neuronx-distributed==0.9.0
notebook==7.2.2
notebook_shim==0.2.4
numpy==1.25.2
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.6.68
nvidia-nvtx-cu12==12.1.105
oauth2client==4.1.3
overrides==7.7.0
packaging==24.1
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pgzip==0.3.5
pillow==10.4.0
platformdirs==4.3.6
prometheus_client==0.21.0
prompt_toolkit==3.0.47
protobuf==3.20.3
psutil==6.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycparser==2.22
Pygments==2.18.0
pyparsing==3.1.4
python-daemon==3.0.1
python-dateutil==2.9.0.post0
python-json-logger==2.0.7
PyYAML==6.0.2
pyzmq==26.2.0
referencing==0.35.1
regex==2024.9.11
requests==2.31.0
requests-unixsocket==0.3.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.20.0
rsa==4.7.2
s3transfer==0.10.2
safetensors==0.4.5
scipy==1.11.2
Send2Trash==1.8.3
six==1.16.0
sniffio==1.3.1
soupsieve==2.6
stack-data==0.6.3
sympy==1.13.3
terminado==0.18.1
tinycss2==1.3.0
tokenizers==0.19.1
tomli==2.0.1
torch==2.1.2
torch-neuronx==2.1.2.2.3.0
torch-xla==2.1.4
torchvision==0.16.2
tornado==6.4.1
tqdm==4.66.5
traitlets==5.14.3
transformers==4.44.2
transformers-neuronx==0.12.313
triton==2.1.0
types-python-dateutil==2.9.0.20240906
typing_extensions==4.12.2
uri-template==1.3.0
uritemplate==3.0.1
urllib3==2.2.3
wcwidth==0.2.13
webcolors==24.8.0
webencodings==0.5.1
websocket-client==1.8.0
wget==3.2
widgetsnbextension==4.0.13

My neuron-profile version is:

neuron-profile 2.19.0.0%kaena-tools/2.19@c48a122 built on 2024-08-02T17:21:14Z

When profiling, I capture the profile of the second execution:

neuron-profile capture -n "$file" -s profile.ntff --profile-nth-exec=2

JonathanHenson commented 3 days ago

Thank you for the issue. We're looking into it and will get back to you as soon as we have something to share.

JonathanHenson commented 3 days ago

That still looks suspiciously like DGE. What are the results if you add --disable-internal-io-dge here:

  nki_matmul_basic_jit = nki_jit(nki_matmul_tiled_, additional_compile_opt="--disable-internal-io-dge")
  nki_matmul_basic_jit(lhs_big.T, rhs_big, output_big)

JonathanHenson commented 3 days ago

I compiled the kernel and inspected the IR output, and I do not see any additional tensors that would cause this behavior. However, when I look at the NEFF, I see a number of GPSIMD instructions when DGE is on and none when it is off. As far as I can tell, that is the most likely culprit.

aws-qieqingy commented 2 days ago

We took a closer look at the issue, and we suspect there is a bug with the profiler or compiler. We are investigating this and will let you know once we have a conclusion.

zhdllwyc commented 2 days ago

That still looks suspiciously like DGE. What are the results if you add --disable-internal-io-dge here:

  nki_matmul_basic_jit = nki_jit(nki_matmul_tiled_, additional_compile_opt="--disable-internal-io-dge")
  nki_matmul_basic_jit(lhs_big.T, rhs_big, output_big)

Hi, thank you for looking into this issue. When I do this, I receive the following error message: __init__() got an unexpected keyword argument 'additional_compile_opt'. I believe I do not have access to a version of the compiler that recognizes this argument.
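One way to avoid this TypeError across installed versions is to inspect the wrapper's signature before passing the keyword. The sketch below uses only the standard library; `accepts_kwarg`, `wrap_kernel`, `OldJit` and `NewJit` are hypothetical names, and the two stand-in classes merely imitate a wrapper with and without the argument. Whether a given torch_neuronx release accepts `additional_compile_opt` is not something this snippet can tell you without running it against the real `nki_jit`.

```python
import inspect


def accepts_kwarg(fn, name):
    """Return True if fn (a function or class) can be called with keyword `name`."""
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # Signature not introspectable (e.g. some C extensions): be conservative.
        return False
    param = params.get(name)
    if param is not None and param.kind is not inspect.Parameter.POSITIONAL_ONLY:
        return True
    # A **kwargs parameter would also swallow the keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


class OldJit:
    """Stand-in for a wrapper that does not know the keyword (hypothetical)."""

    def __init__(self, func):
        self.func = func
        self.compile_opt = None


class NewJit:
    """Stand-in for a wrapper that accepts the keyword (hypothetical)."""

    def __init__(self, func, additional_compile_opt=""):
        self.func = func
        self.compile_opt = additional_compile_opt


def wrap_kernel(jit_cls, kernel, flag):
    """Pass the flag only when the wrapper's signature allows it."""
    if accepts_kwarg(jit_cls, "additional_compile_opt"):
        return jit_cls(kernel, additional_compile_opt=flag)
    return jit_cls(kernel)


def dummy_kernel():
    pass


old = wrap_kernel(OldJit, dummy_kernel, "--disable-internal-io-dge")
new = wrap_kernel(NewJit, dummy_kernel, "--disable-internal-io-dge")
print("old wrapper received flag:", old.compile_opt)
print("new wrapper received flag:", new.compile_opt)
```

With the real wrapper, the call would be `wrap_kernel(nki_jit, nki_matmul_tiled_, "--disable-internal-io-dge")`; when the keyword is unavailable, the flag is simply not passed, and the script above still sets NEURON_CC_FLAGS as before.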