Open mars1248 opened 1 week ago
Thanks for your report, perhaps it's better to start by describing the problem you'd like to solve?
Thank you for your answer. Last year I implemented bicubic resize as a custom call, to be used by torch xla's upsampler_bicubic. The previous custom call implementation always passed a StridedMemrefView as the parameter, but now it seems to be a BufferAllocation::Slice, so we can't get the stride information of the tensor directly.
Ah, interesting. I don't think XLA ever supported StridedMemRef (or maybe it pretended to support it, but actually ignored the stride). We've just discussed adding support for strided views as a possible feature. CC @ezhulenev @bchetioui
You can use XLA:FFI; see the example in jax: https://github.com/google/jax/blob/main/docs/cuda_custom_call/foo.cu.cc#L71. You can get access to buffer dimensions via the xla::ffi::Buffer argument; however, it is not strided, in the sense that it is always a dense buffer (XLA never supported truly strided buffers).
Instead of the removed custom call APIs, you can use XLA FFI: https://github.com/openxla/xla/tree/main/xla/ffi. It is very similar in spirit, and the APIs look very similar as well; only the internal implementation details differ.
> XLA never supported truly strided buffers
Should we support them? In the sense that a custom call gets a "stride" argument, and XLA generates such "strided" arguments after, e.g., a slice?
Yes, that is not too hard to add on top of dynamic slice fusion. However, I guess most custom calls are not prepared to handle strides, but this can be done with an explicit opt-in.
Yeah exactly, there should be a way to register such custom calls.
Last year I followed the implementation logic in this file, https://github.com/openxla/xla/blob/7954169ccfb6290d94af3ea3634229b097682ba8/xla/service/gpu/runtime/gemm.cc, where the input parameters were defined as StridedMemrefView. The master code now makes almost no use of this interface and has changed to BufferAllocation::Slice, but my logic needs a raw view to get the size of each dimension. What interface should I use? I also found that custom calls are now implemented with thunks; why is this?