aws-neuron / nki-samples

MIT No Attribution

Running Custom Operators for PyTorch with NKI #25

nandeeka opened this issue 1 week ago (status: Open)

nandeeka commented 1 week ago

I was curious about how the Neuron compiler implements torch.topk(). After writing a kernel that calls it and compiling it to NKI, I get generated code that looks something like this:

    """ tensor_op_name: xla___op_TopKImpl_custom-call.1 | hlo_id: 49 |  """
    v30[nl.arange(124)[:, None], nl.arange(16)[None, :]] = nl.load(v14.reshape([1984])[nl.arange(16)[None, :]+16*nl.arange(124)[:, None]], dtype=np.float16, mask=None)
    """ tensor_op_name: xla___op_TopKImpl_custom-call.1 | hlo_id: 49 |  """
    # NkiCodegen.codegenSundaMax8 is not implemented
    """ tensor_op_name: xla___op_TopKImpl_custom-call.1 | hlo_id: 49 |  """
    # NkiCodegen.codegenSundaMaxIndex8 is not implemented
    """ tensor_op_name: xla___op_TopKImpl_custom-call.1 | hlo_id: 49 |  """
    nl.store(v17[4*nl.arange(124)[:, None]+nl.arange(4)[None, :]], value=v15[0, nl.arange(124)[:, None], nl.arange(4)[None, :]], mask=None)
    """ tensor_op_name: xla___op_TopKImpl_custom-call.1 | hlo_id: 49 |  """
    nl.store(v18.reshape([496])[4*nl.arange(124)[:, None]+nl.arange(4)[None, :]], value=v16[0, nl.arange(124)[:, None], nl.arange(4)[None, :]], mask=None)
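The index expressions in the generated code are easier to follow once expanded. The plain-Python sketch below (no NKI calls; the helper names are hypothetical) checks that the `nl.load` reads the 1984-element `v14` as a row-major 124×16 tile and that both `nl.store` calls write a row-major 124×4 tile into 496-element outputs. That pattern suggests a top-4 over a last dimension of size 16, but the shapes and `k` are inferred from the constants in the generated code, not stated by the compiler output.

```python
# Hypothetical helpers: they only re-derive the index arithmetic in the
# generated NKI code above; nothing here calls NKI.
ROWS, COLS, K = 124, 16, 4  # read off nl.arange(124), nl.arange(16), nl.arange(4)

def load_index(p: int, f: int) -> int:
    """Flat offset read by the nl.load: 16 * partition index + free index."""
    return COLS * p + f

def store_index(p: int, j: int) -> int:
    """Flat offset written by both nl.store calls: 4 * partition index + j."""
    return K * p + j

# Every element of the 1984-element input is read exactly once, in row-major
# order, so the load behaves like a plain reshape to (124, 16).
reads = [load_index(p, f) for p in range(ROWS) for f in range(COLS)]
assert reads == list(range(ROWS * COLS))

# The stores likewise cover a 496-element output as a row-major (124, 4) tile.
writes = [store_index(p, j) for p in range(ROWS) for j in range(K)]
assert writes == list(range(ROWS * K))

print(ROWS * COLS, ROWS * K)  # 1984 496
```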

I understand that the NkiCodegen.codegenSundaMax8 and NkiCodegen.codegenSundaMaxIndex8 functions need to be black boxes because they require data-dependent control flow, which is currently not supported. My question is, is it possible for me to call these functions within my own NKI kernel?
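To make the question concrete, here is a plain-Python reference of what the two missing instructions appear to compute. It assumes, from their names and from the 4-wide stores above, that Max8 returns the eight largest values of each row and MaxIndex8 returns their positions, so a top-k with k ≤ 8 is just the first k entries of each. The function names are hypothetical and these semantics are an assumption, not documented hardware behaviour. The value-dependent ordering is the part that needs the data-dependent control flow mentioned above.

```python
from typing import List, Tuple

# ASSUMPTION: semantics inferred from the names SundaMax8 / SundaMaxIndex8
# and from the generated code; this is not documented instruction behaviour.

def max_index8(row: List[float]) -> List[int]:
    """Hypothetical reference: positions of the 8 largest values, largest first.

    Ties are broken by lower index in this sketch; torch.topk does not
    promise any particular tie order.
    """
    order = sorted(range(len(row)), key=lambda i: (-row[i], i))
    return order[:8]

def max8(row: List[float]) -> List[float]:
    """Hypothetical reference: the 8 largest values of a row, descending."""
    return [row[i] for i in max_index8(row)]

def topk_row(row: List[float], k: int) -> Tuple[List[float], List[int]]:
    """torch.topk-style (values, indices) for one row, for k <= 8."""
    if not 1 <= k <= 8:
        raise ValueError("this sketch only covers 1 <= k <= 8")
    return max8(row)[:k], max_index8(row)[:k]
```

For a 16-element row such as `[0.5, 3.0, -1.0, 2.0, 7.0, 3.0, 0.0, 1.5, 4.0, -2.0, 6.0, 0.25, 5.0, 2.5, -0.5, 1.0]`, `topk_row(row, 4)` gives `([7.0, 6.0, 5.0, 4.0], [4, 10, 12, 8])`.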

Thanks!

JonathanHenson commented 1 week ago

Thanks for filing the issue! We're looking into it actively and will get back with an answer for you shortly.

aws-qieqingy commented 1 week ago

Hi Nandeeka! This requires us to implement support for SundaMax8 in NKI, which we have added to our backlog of active tasks. We'll let you know when it's available.

AWSNB commented 1 week ago

@aws-qieqingy is there anything @nandeeka can do to work around this limitation in the short term?

JonathanHenson commented 1 week ago

Also, a note: NkiCodegen is currently experimental, so it will run into NKI instructions that are not yet implemented until we have finished expressing every instruction in the NKI ISA. Whenever this class of error appears (a "... is not implemented" comment in the generated code), that is most likely the reason.

aws-qieqingy commented 1 week ago

@AWSNB Unfortunately, NKI does not currently expose any instruction that achieves the same functionality.
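On the short-term question, one possible workaround (an untested sketch, not something proposed by the Neuron team in this thread) is to avoid sorting altogether: run k rounds of a max reduction, find the first position holding that max with an elementwise select followed by a min reduction, then overwrite that position with -inf. The loop trip count depends only on k, so no step branches on the data. Whether each step maps onto operations NKI currently exposes, and how it would perform compared with a native Max8, has not been verified here; the sketch is plain Python and the function name is hypothetical.

```python
from typing import List, Tuple

NEG_INF = float("-inf")

def topk_by_masked_max(row: List[float], k: int) -> Tuple[List[float], List[int]]:
    """Hypothetical top-k built from reductions and elementwise selects only.

    Each round: max-reduce, locate the first argmax, mask it out with -inf.
    Ties resolve to the lowest index.
    """
    n = len(row)
    if not 1 <= k <= n:
        raise ValueError("k must be between 1 and len(row)")
    iota = list(range(n))   # position of each element, like an iota tile
    work = list(row)        # working copy; the caller's data is untouched
    values: List[float] = []
    indices: List[int] = []
    for _ in range(k):      # trip count depends only on k, not on the data
        m = max(work)       # row-wise max reduction
        # Elementwise select (index where equal to m, else n), then min-reduce,
        # gives the first position holding the max.
        idx = min(i if w == m else n for i, w in zip(iota, work))
        values.append(m)
        indices.append(idx)
        # Elementwise select: replace the chosen element with -inf.
        work = [NEG_INF if i == idx else w for i, w in zip(iota, work)]
    return values, indices
```

For example, `topk_by_masked_max([1.0, 3.0, 3.0], 2)` returns `([3.0, 3.0], [1, 2])`: the duplicate maximum is picked up in the second round because only the first copy was masked.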