artyom-beilis / pytorch_dlprim

DLPrimitives/OpenCL out of tree backend for pytorch
http://blog.dlprimitives.org/
MIT License
264 stars, 17 forks

Some misc. unimplemented features #22

Open nonnull-ca opened 1 year ago

nonnull-ca commented 1 year ago

I spent a chunk of yesterday seeing how far I could get with running this. I ran into a bunch of issues of varying severity; I don't know whether you would prefer them all in one issue or split up.

Trivial issues:

  1. aten::rsub unimplemented.
  2. aten::eq unimplemented.

(I'll probably do up a PR for these two later today.)
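For reference, here is a minimal sketch of the semantics these two ops should have, written in plain Python over flat lists rather than tensors. As I understand the PyTorch docs, aten::rsub(self, other, alpha) computes other - alpha * self, and aten::eq returns an elementwise boolean mask. The helper names are hypothetical. Broadcasting, dtype promotion and the out= variants are deliberately left out.

```python
# Reference semantics for aten::rsub and aten::eq on flat lists.
# Hypothetical helpers: same-length inputs only, no broadcasting,
# no dtype promotion, no out= variants.

def rsub(self, other, alpha=1):
    """aten::rsub: other - alpha * self, elementwise.

    `other` may be a list of the same length or a plain Python scalar.
    """
    if isinstance(other, (int, float)):
        other = [other] * len(self)
    return [o - alpha * s for s, o in zip(self, other)]


def eq(self, other):
    """aten::eq: elementwise equality, returning a boolean mask."""
    if isinstance(other, (int, float)):
        other = [other] * len(self)
    return [s == o for s, o in zip(self, other)]


print(rsub([1.0, 2.0, 3.0], 1.0))         # [0.0, -1.0, -2.0]
print(rsub([1.0, 2.0], [10.0, 10.0], 2))  # [8.0, 6.0]
print(eq([1, 2, 3], [1, 0, 3]))           # [True, False, True]
```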

Simple issues:

  1. Arithmetic between a tensor and a scalar (e.g. 1.0 - someTensor).
    1. For now, to get things working, I had to change the caller to torch.tensor(1.0) - someTensor.
    2. It's possible this is just fallout of aten::rsub being unimplemented.
  2. aten::any_all_out unimplemented.
    1. This would be fairly straightforward to implement, except I can't actually find the precise semantics anywhere.
    2. I took a stab at it anyway and it started failing elsewhere, so... progress?
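Here is a stdlib-only sketch of both points, under stated assumptions. FakeTensor is a hypothetical stand-in for torch.Tensor. When Python evaluates 1.0 - t, float.__sub__ returns NotImplemented, so Python falls back to t.__rsub__(1.0). As far as I know, PyTorch maps that to torch.rsub, i.e. aten::rsub. Writing torch.tensor(1.0) - t puts a tensor on the left instead, so the ordinary sub path is taken. That fits the guess that the scalar problem is rsub fallout.

The second half is my reading of aten::any.all_out: reduce all elements to a single value that is true iff any element is nonzero, and write it into out. NaN counts as nonzero, and an empty input gives False. Per the torch.any docs, the result dtype is bool, except that uint8 input gives uint8.

```python
# Part 1: why `scalar - tensor` ends up in rsub.
# FakeTensor is a hypothetical stand-in for torch.Tensor's operators.

dispatched = []  # which "aten" op each expression was routed to


def _broadcast(a, b):
    """Broadcast two lists when one of them has length 1."""
    n = max(len(a), len(b))
    a = a * n if len(a) == 1 else a
    b = b * n if len(b) == 1 else b
    return a, b


class FakeTensor:
    def __init__(self, data):
        self.data = list(data)

    def __sub__(self, other):
        # tensor - other: the analogue of aten::sub
        dispatched.append("sub")
        rhs = other.data if isinstance(other, FakeTensor) else [other]
        a, b = _broadcast(self.data, rhs)
        return FakeTensor(x - y for x, y in zip(a, b))

    def __rsub__(self, other):
        # scalar - tensor: float.__sub__ returned NotImplemented, so Python
        # calls this reflected method instead, the analogue of aten::rsub
        dispatched.append("rsub")
        a, b = _broadcast([other], self.data)
        return FakeTensor(x - y for x, y in zip(a, b))


t = FakeTensor([0.25, 0.5])
r1 = 1.0 - t                # routed to __rsub__
r2 = FakeTensor([1.0]) - t  # the torch.tensor(1.0) workaround: plain sub
print(dispatched)           # ['rsub', 'sub']
print(r1.data, r2.data)     # [0.75, 0.5] [0.75, 0.5]


# Part 2: reference semantics for aten::any.all_out on a flat list.

def any_all(values):
    """True iff at least one element is nonzero; NaN counts as nonzero."""
    return any(v != 0 for v in values)  # empty input -> False


def any_all_out(values, out):
    """Write the full reduction into a one-element `out` buffer."""
    out[0] = any_all(values)
    return out


out = [None]
print(any_all_out([0, 0, 3], out))                # [True]
print(any_all([]), any_all([0.0, float("nan")]))  # False True
```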

Complex issues:

  1. No support for aten::native_batch_norm without running mean & variance. (It trips an assert in the backend. Note that this is actually being hit when calling layer_norm.)
  2. No aten::index_select support.
    1. ...and dlprim appears to consist mostly of various kernels (none of which seem to match), elementwise operations, and reductions.
    2. If you could point me at some test cases for this and recommend an approach, I could probably take a stab at it.
  3. No half-precision support.
    1. I took a stab at plumbing this in... but then I hit the wall of 'all of the backend kernels for dlprim appear to be hardcoded for single-precision input vectors'.
    2. In particular, native_batch_norm, but I suspect there's more than just that.
  4. Not your fault, but there's a surprising number of callerside changes that need to happen when using the private device backend.
    1. Even once aten is fully implemented, this is nowhere near a drop-in replacement backend.
    2. Lots of code calling torch.cuda.device_count() / torch.cuda.device / etc.
    3. Lots of code assuming integer device indexes work, when you need 'privateuseone:0' instead.
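Here is a minimal sketch of the sort of shim this forces on callers. It assumes the backend is exposed under PyTorch's default PrivateUse1 device name, privateuseone. to_device_string is a hypothetical helper, not part of PyTorch or dlprim. It rewrites bare integer indexes and cuda:N strings into device strings the out-of-tree backend accepts. It does nothing about torch.cuda.device_count() and friends, which still need per-call-site changes.

```python
# Hypothetical caller-side shim for CUDA-oriented code running on an
# out-of-tree PrivateUse1 backend. Assumes the default device name
# "privateuseone"; adjust BACKEND if the backend has been renamed.

BACKEND = "privateuseone"


def to_device_string(spec, backend=BACKEND):
    """Normalise a device spec to e.g. 'privateuseone:0'.

    Accepts a bare int index, 'cuda', 'cuda:N', or an already-qualified
    backend string; 'cpu' specs are passed through unchanged.
    """
    if isinstance(spec, bool):  # bool is an int subclass; reject it
        raise TypeError("device index must be an int, not bool")
    if isinstance(spec, int):
        if spec < 0:
            raise ValueError(f"negative device index: {spec}")
        return f"{backend}:{spec}"
    kind, _, index = spec.partition(":")
    if kind == "cpu":
        return spec
    if kind in ("cuda", backend):
        return f"{backend}:{index or 0}"  # a missing index defaults to 0
    raise ValueError(f"unrecognised device spec: {spec!r}")


print(to_device_string(0))         # privateuseone:0
print(to_device_string("cuda:1"))  # privateuseone:1
print(to_device_string("cpu"))     # cpu
```

The resulting string can then go wherever a device string is expected, such as torch.device(...) or tensor.to(...).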

...and that's where I ran out of time.

Is there documentation of the aten API anywhere?

artyom-beilis commented 1 year ago

Trivial issues:

  • aten::rsub unimplemented.
  • aten::eq unimplemented.

Take a look at src/pointwise_ops.cpp: many ops are implemented there, and some of these are trivial to add.

No support for aten::native_batch_norm

Actually, it is usually called from layer_norm, and support for it is on its way for transformers; see https://github.com/artyom-beilis/pytorch_dlprim/discussions/16
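This is why layer_norm lands in native_batch_norm, as far as I can tell. PyTorch's generic (math) implementation of native_layer_norm reshapes the input so that each sample becomes one batch-norm "channel". It then calls native_batch_norm in training mode with no running statistics. Both paths therefore reduce to the per-row computation sketched below in plain Python. layer_norm_rows is a hypothetical helper. The affine weight/bias are omitted, and eps=1e-5 matches the usual default.

```python
import math


def layer_norm_rows(rows, eps=1e-5):
    """Normalise each row to zero mean / unit variance.

    This mirrors what native_batch_norm computes per channel in training
    mode without running_mean/running_var: the statistics come from the
    data itself, and the variance is the biased (divide-by-N) one.
    Also returns the per-row mean and inverse std (invstd), which the
    backward pass needs.
    """
    out, means, invstds = [], [], []
    for row in rows:
        n = len(row)
        mean = sum(row) / n
        var = sum((x - mean) ** 2 for x in row) / n  # biased variance
        invstd = 1.0 / math.sqrt(var + eps)
        out.append([(x - mean) * invstd for x in row])
        means.append(mean)
        invstds.append(invstd)
    return out, means, invstds


y, mean, invstd = layer_norm_rows([[1.0, 2.0, 3.0, 4.0]])
print([round(v, 4) for v in y[0]])  # [-1.3416, -0.4472, 0.4472, 1.3416]
```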

No aten::index_select support

It looks like something that requires a custom but simple GPU kernel... though it is probably more complex in the backward pass. If you can give me a simple example (MNIST-style) of when/how it is used in the real world, I can look into it.
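For reference, here is a pure-Python sketch of the semantics (2-D only, hypothetical helper names). The forward pass is a gather along dim. The backward pass scatters the incoming gradient back with accumulation, since the same index may appear more than once. That accumulation (an index_add) is plausibly what makes the backward harder on a GPU: naive parallel writes to a repeated row race, so a kernel needs atomics or a sort-and-reduce step. As far as I know, one common real-world source of index_select is embedding lookup, where PyTorch's native embedding forward selects rows of the weight matrix by token id.

```python
def index_select(x, dim, index):
    """Gather rows (dim=0) or columns (dim=1) of a 2-D list by index."""
    if dim == 0:
        return [list(x[i]) for i in index]
    if dim == 1:
        return [[row[i] for i in index] for row in x]
    raise ValueError("this sketch only handles 2-D inputs")


def index_select_backward(grad_out, input_shape, dim, index):
    """Gradient w.r.t. x: scatter grad_out back with accumulation."""
    rows, cols = input_shape
    grad_in = [[0.0] * cols for _ in range(rows)]
    for k, i in enumerate(index):
        if dim == 0:
            # output row k came from input row i; repeats accumulate
            for c in range(cols):
                grad_in[i][c] += grad_out[k][c]
        else:
            # output column k came from input column i; repeats accumulate
            for r in range(rows):
                grad_in[r][i] += grad_out[r][k]
    return grad_in


# Embedding-style lookup: token ids select rows of a weight matrix.
weight = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]
tokens = [2, 0, 2]  # note the repeated id
print(index_select(weight, 0, tokens))
# [[2.0, 2.1], [0.0, 0.1], [2.0, 2.1]]
print(index_select_backward([[1.0, 1.0]] * 3, (3, 2), 0, tokens))
# [[1.0, 1.0], [0.0, 0.0], [2.0, 2.0]]
```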

No half-precision support

The biggest issue there is the GEMM and Winograd kernels: writing optimised versions of them is going to be hard, and without them half precision isn't that useful beyond storage. bfloat16 is similar... Currently I just don't have the capacity to handle it. I think it is somewhat lower priority, since most things can be done just using regular float.
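Here is one way to read "useful beyond storage". Data can be kept in IEEE half precision (2 bytes per value) and widened for the actual arithmetic. That halves memory and bandwidth but gives none of the compute speedup that dedicated fp16 GEMM kernels would. The sketch below is stdlib-only and uses Python's struct 'e' (binary16) format. The helper names are hypothetical, and this is not how dlprim stores tensors.

```python
import struct


def pack_half(values):
    """Store floats as IEEE 754 binary16: 2 bytes per value."""
    return struct.pack(f"<{len(values)}e", *values)


def unpack_half(buf):
    """Widen binary16 storage back to Python floats (doubles)."""
    return list(struct.unpack(f"<{len(buf) // 2}e", buf))


def dot_half_storage(a_buf, b_buf):
    """Dot product: half-precision storage, full-precision arithmetic."""
    return sum(x * y for x, y in zip(unpack_half(a_buf), unpack_half(b_buf)))


a = pack_half([0.1, 0.2, 0.3])
b = pack_half([1.0, 2.0, 3.0])
print(len(a))             # 6 (bytes; float32 storage would need 12)
print(unpack_half(a)[0])  # 0.0999755859375 (rounding from half storage)
print(dot_half_storage(a, b))
```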

Not your fault, but there's a surprising number of callerside changes that need to happen when using the private device backend.

Obviously :-) but these are the least of the problems at this point.

Is there documentation of the aten API anywhere?

The function signatures that need to be implemented (of course, not all of them) can be found in ATen/RegistrationDeclarations.h; you can find it at $VIRTUAL_ENV/lib/python3.*/site-packages/torch/include/ATen/RegistrationDeclarations.h
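Here is a small stdlib-only sketch for finding that header and grepping it for an op without hardcoding the Python version. It just expands the glob pattern above under a given prefix such as $VIRTUAL_ENV. find_registration_declarations and grep_ops are hypothetical helpers. The sketch assumes a Unix-style virtualenv layout (Windows and some conda installs differ), so treat the pattern as an assumption.

```python
import glob
import os

# Unix-style virtualenv layout, mirroring the path quoted above.
PATTERN = os.path.join(
    "lib", "python3.*", "site-packages", "torch", "include", "ATen",
    "RegistrationDeclarations.h",
)


def find_registration_declarations(prefix=None):
    """Return matching header paths under prefix (default: $VIRTUAL_ENV)."""
    prefix = prefix or os.environ.get("VIRTUAL_ENV", "")
    if not prefix:
        return []
    return sorted(glob.glob(os.path.join(prefix, PATTERN)))


def grep_ops(path, needle):
    """Return the declaration lines that mention `needle`, e.g. 'rsub'."""
    with open(path, encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f if needle in line]


for header in find_registration_declarations():
    for line in grep_ops(header, "index_select"):
        print(line)
```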

Other than that, there are some docs in the source code, and most functions that appear in aten:: have an exact Python counterpart with public docs.