openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

How to use StridedMemrefView, or is there an alternative interface? #14194

Open mars1248 opened 1 week ago

mars1248 commented 1 week ago

Last year I followed the implementation logic in this file, https://github.com/openxla/xla/blob/7954169ccfb6290d94af3ea3634229b097682ba8/xla/service/gpu/runtime/gemm.cc, where the input parameters were defined as StridedMemrefView. The master code now makes almost no use of this interface and has changed to BufferAllocation::Slice, but my logic needs a raw view to get the size of each dimension. What interface should I use? I also found that custom calls are now implemented with thunks. Why is this?

cheshire commented 1 week ago

Thanks for your report, perhaps it's better to start by describing the problem you'd like to solve?

mars1248 commented 1 week ago

> Thanks for your report, perhaps it's better to start by describing the problem you'd like to solve?

Thank you for your answer. Last year I implemented resize bicubic as a custom call, to be used by torch xla's upsampler_bicubic. The previous custom call implementation always passed StridedMemrefView as a parameter, but now it seems to be BufferAllocation::Slice, so we can't get the stride information of the tensor directly.

cheshire commented 1 week ago

Ah interesting. I don't think XLA ever supported StridedMemRef (or maybe it pretended to support it, but actually ignored the stride). We've just discussed adding support for strided views as a possible feature. CC @ezhulenev @bchetioui

ezhulenev commented 1 week ago

You can use XLA:FFI, see the example in jax: https://github.com/google/jax/blob/main/docs/cuda_custom_call/foo.cu.cc#L71. You can get access to the buffer dimensions via the xla::ffi::Buffer argument; however, it is not strided, in the sense that it is always a dense buffer (XLA never supported truly strided buffers).

Instead of the removed custom call APIs you can use XLA FFI: https://github.com/openxla/xla/tree/main/xla/ffi. It is very similar in spirit, and the APIs look very similar as well; only the internal implementation details differ.
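
Because the buffer is always dense, the per-dimension strides that StridedMemrefView used to carry can be recomputed from the dimensions alone. The following is a minimal sketch of that idea, assuming the default row-major layout (last dimension contiguous). `DenseRowMajorStrides` and `LinearOffset` are hypothetical helpers, not part of XLA, and the `dims` vector stands in for the dimensions reported by the xla::ffi::Buffer argument.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an XLA API): element strides of a dense buffer,
// ASSUMING a row-major layout where the last dimension is contiguous.
std::vector<int64_t> DenseRowMajorStrides(const std::vector<int64_t>& dims) {
  std::vector<int64_t> strides(dims.size());
  int64_t stride = 1;
  // Walk from the innermost (last) dimension to the outermost (first).
  for (std::size_t i = dims.size(); i-- > 0;) {
    strides[i] = stride;
    stride *= dims[i];
  }
  return strides;
}

// Hypothetical helper: linear element offset of a multi-dimensional index,
// given strides expressed in elements (not bytes).
int64_t LinearOffset(const std::vector<int64_t>& index,
                     const std::vector<int64_t>& strides) {
  int64_t offset = 0;
  for (std::size_t i = 0; i < index.size(); ++i) {
    offset += index[i] * strides[i];
  }
  return offset;
}
```

For a buffer with dimensions {2, 3, 4} this gives strides {12, 4, 1}, so element (1, 2, 3) sits at offset 23. If the custom call is registered with a non-default operand layout, the row-major assumption no longer holds and the strides would have to follow that layout instead.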

cheshire commented 1 week ago

> XLA never supported truly strided buffers

Should we support them? In the sense that a custom call gets a "stride" argument, and XLA generates such "strided" arguments after e.g. a slice?

ezhulenev commented 1 week ago

Yes, that would not be too hard to add on top of dynamic slice fusion. However, I guess most custom calls are not prepared to handle strides, so this would have to be an explicit opt-in.

cheshire commented 1 week ago

Yeah exactly, there should be a way to register such custom calls.