Open vadimkantorov opened 1 year ago
It might also be nice to add opaque dtypes torch.bits1x512, torch.bits1x256 and torch.bits1x128; they should allow expressing custom layouts.
If opaque dtypes are now being created, it might be easier to introduce these dtypes too: https://github.com/pytorch/pytorch/pull/115044
(These dtypes are needed only for a few basic things: transposing, conversions, viewing back as a scalar dtype, and expressing memory layouts that are optimized for vectorized/wide memory reads but are otherwise column-major.)
I think this is relevant to what we were trying to do in https://github.com/pytorch-labs/ao/pull/13: for introducing new dtypes or new types of Tensor, we'd like to start by using tensor subclassing by default, and implement operations on the tensor subclass. The tensor subclass can be backed by a bits8 Tensor, which is uninterpreted by itself.
Could you describe which operations are needed for this tensor? My current understanding is that bits8 is enough for everything since it's uninterpreted: all the semantic information about the Tensor (e.g. how it's packed, how to transpose or view it) can be stored in the tensor subclass, or in the operator (e.g. some_operator_works_with_int4_tensor that takes the bits8 dtype as input).
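A minimal sketch of that idea, assuming uint8 as a stand-in for an uninterpreted bits8 buffer: two signed int4 values are packed per byte, and all knowledge of the packing lives in a small Python wrapper rather than in the dtype. The names `pack_int4`, `unpack_int4` and `Int4Matrix` are hypothetical, not an existing torchao API; transpose is done by unpacking, transposing and repacking, which is exactly the kind of semantic operation the wrapper has to own.

```python
import torch


def pack_int4(x: torch.Tensor) -> torch.Tensor:
    """Pack signed int4 values (range [-8, 7]) pairwise along the last dim into uint8."""
    assert x.shape[-1] % 2 == 0, "last dim must be even"
    nibbles = (x.to(torch.int16) & 0xF).to(torch.uint8)  # low 4 bits, two's complement
    lo, hi = nibbles[..., 0::2], nibbles[..., 1::2]
    return lo | (hi << 4)  # even columns in the low nibble, odd columns in the high nibble


def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_int4: returns int8 with twice as many columns."""
    p = packed.to(torch.int16)
    lo, hi = p & 0xF, (p >> 4) & 0xF
    # Sign-extend 4-bit two's complement: nibble values >= 8 encode negatives.
    lo = torch.where(lo >= 8, lo - 16, lo)
    hi = torch.where(hi >= 8, hi - 16, hi)
    return torch.stack([lo, hi], dim=-1).flatten(-2).to(torch.int8)


class Int4Matrix:
    """Hypothetical wrapper: the uint8 buffer is uninterpreted, the semantics live here."""

    def __init__(self, packed: torch.Tensor):
        self.packed = packed  # shape (rows, cols // 2), dtype uint8

    @classmethod
    def from_int8(cls, x: torch.Tensor) -> "Int4Matrix":
        return cls(pack_int4(x))

    def to_int8(self) -> torch.Tensor:
        return unpack_int4(self.packed)

    def t(self) -> "Int4Matrix":
        # A transpose needs the packing semantics: unpack, transpose, repack
        # (this sketch requires an even number of rows).
        return Int4Matrix.from_int8(self.to_int8().t().contiguous())
```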
Well, I just noticed that there exist memory layouts like COL32 (and something similar for MKL-DNN) that currently cannot be represented logically in PyTorch with strides, nor manipulated or converted to/from. I think allowing arbitrary or very wide dtypes could help represent all these layouts natively and remove hacks from those backends. It would also make it possible to state explicitly that these kernels (COL32/4R2) should be used, and to prepare the layout manually, avoiding excessive automatic conversions.
Regarding these, I think conversions, views and restrides/transposes alone would be enough to be useful, plus maybe some slow-ish item accessors for printing/debugging.
Regarding bits8, pack/unpack would be very useful: https://github.com/pytorch/ao/issues/292. In the future, it might also be useful to introduce a BitTensor/BitMask as a compressed form of BoolTensor.
cc @ezyang
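For the BitMask idea, a minimal pack/unpack sketch using only existing PyTorch ops: a bool tensor is packed 8 elements per uint8 byte along its last dimension. The helper names and the least-significant-bit-first bit order are assumptions chosen for illustration, not the torchao API discussed in the linked issue.

```python
import torch


def pack_bits(mask: torch.Tensor) -> torch.Tensor:
    """Pack a bool tensor along its last dim (size divisible by 8) into uint8, LSB first."""
    assert mask.dtype == torch.bool and mask.shape[-1] % 8 == 0
    shifts = torch.arange(8, device=mask.device)  # bit positions 0..7
    bits = mask.reshape(*mask.shape[:-1], -1, 8).to(torch.int64)
    return (bits << shifts).sum(dim=-1).to(torch.uint8)


def unpack_bits(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_bits: uint8 (..., n) -> bool (..., 8 * n)."""
    shifts = torch.arange(8, device=packed.device)
    bits = (packed.to(torch.int64).unsqueeze(-1) >> shifts) & 1
    return bits.bool().flatten(-2)
```

The packed form uses 1/8 of the memory of a BoolTensor, at the cost of needing kernels (or an unpack) that understand the bit layout.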
I suppose you could add a bits128 and bigger types. But I'm not really sure there's all that much payoff. In particular, instead of a float16x32 tensor of size (S), you can instead have a float16 tensor of size (S, 32). Sure, you have to dork around with the last dimension, but with a bits128 type you have to dork around with the dtype, so you're not really gaining anything. Just simulate it with a tensor subclass.
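A small sketch of simulating a wide element with an extra trailing dimension, as suggested above. `Tensor.view(dtype)` already reinterprets the bytes of the last (contiguous) dimension, so one 64-byte "float16x32" element per row can be seen as 8 int64 words or 64 raw bytes without copying.

```python
import torch

S = 4
# "float16x32" simulated as a plain float16 tensor with a trailing dim of 32.
x = torch.zeros(S, 32, dtype=torch.float16)

# Each row is one 64-byte "wide element"; view(dtype) reinterprets the last dim in place.
words = x.view(torch.int64)  # 8 int64 words per row
raw = x.view(torch.uint8)    # 64 raw bytes per row, roughly what an opaque bits dtype gives
```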
We have talked, in the past, about adding native support for blocked layout (cc @colesbury). But in the end we didn't, in part because NVIDIA told us that they didn't want us relying on the specifics of the memory layout (cc @jjsjann123). And while these are interesting, they're not that interesting... not enough to warrant actually putting in the dtype.
If the ask is for a slick interface where torch.empty(dtype=float16x32_blocked) works, then we probably should figure out how to do user defined dtypes. cc @bdhirsh
My idea was mostly to be able to express these blocked layouts so that both col32 and MKL layouts can be bound nicely: to ensure that these kernels specifically are called, that this layout is propagated/preserved (similar to FasterTransformer), and that no (or fewer) conversions happen in a hidden way. I think this is mostly useful for testing/experimenting with bindings (element-wise ops can be performed on the raw storage).
Also, if arbitrary wide custom dtypes are supported, they could be used as a substitute for a fixed-length string dtype, as in numpy: https://numpy.org/doc/stable/reference/arrays.dtypes.html
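PyTorch has no fixed-length string dtype, so as a sketch under that assumption, the same effect can be simulated with a uint8 tensor whose trailing dimension is the string width, similar in spirit to numpy's `S<width>` dtype: zero-padded and truncated to the width. `encode_fixed`/`decode_fixed` are hypothetical helpers and only handle ASCII.

```python
import torch


def encode_fixed(strings, width: int) -> torch.Tensor:
    """Store ASCII strings as an (N, width) uint8 tensor, zero-padded and truncated to width."""
    out = torch.zeros(len(strings), width, dtype=torch.uint8)
    for i, s in enumerate(strings):
        b = s.encode("ascii")[:width]  # truncate to the fixed width
        out[i, : len(b)] = torch.tensor(list(b), dtype=torch.uint8)
    return out


def decode_fixed(t: torch.Tensor):
    """Inverse of encode_fixed: strip the zero padding from each row."""
    return [bytes(row.tolist()).rstrip(b"\x00").decode("ascii") for row in t]
```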
Btw, would just adding wide dtypes be sufficient to represent the blocked layout with only strides? (E.g. without a new "memory_format": the col32 memory format could be deduced from the wide dtype and certain strides.)
Maybe adding these dtypes can allow upstreaming some FasterTransformer kernels for col32?
Yeah, I think the right way to do this is with a tensor subclass. Especially because even if you have a generic blocked dtype, it still doesn't tell you WHAT blocking you are using, and you could potentially still mix them up. And then there is just not that much interest in col32 (though I could be wrong jaja)
@ezyang you are right: without checking the strides (i.e. deducing the actual "memory_format"), one cannot call the col32 kernels, but this might be extended later. Or, if the kernels for these funny memory_formats are available for manual calling, having tested functions for restriding/reordering/contiguous calls should already be useful.
Am I right that wide dtypes + arbitrary strides are sufficient for representing memory format for onednn and cublasLt+col32? If so, then tensor subclasses are not mandatory.
The whole of FasterTransformer was built to preserve col32 and call col32 kernels, I think :)
My main problem with tensor subclasses is that they make it appear that they compose well with other subclasses and ops, while in fact this is most often not the case. For subclass wrappers over dense tensors, it would help to have at least some auto-rewrap peephole dispatch for elementwise ops, or some way to enable auto-casting to a dense tensor if an op is not supported. But overall, I don't quite trust the proliferation of subclasses (especially if the underlying storage is a dense tensor), mainly because I lack intuition about composability with other aspects and multi-aspect combinations, e.g. a quantized sparse tensor stored in col32 :) I also wonder if a weaker typing instrument, like attaching a tag dict to dense tensor objects, could support some dispatch mechanisms.
> Am I right that wide dtypes + arbitrary strides are sufficient for representing memory format for onednn and cublasLt+col32?
I mean, it depends on what you mean by "represent". Obviously without adding an actual dtype col32 you won't get any nominal distinction. But for a useful structural representation, I wouldn't bother with the wide dtype though, shove the vector count as the last stride and you're all done.
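A minimal sketch of that suggestion for COL32: an ordinary strided tensor whose last dimension is the 32-wide vector. The offset formula in the docstring is my reading of the CUBLASLT_ORDER_COL32 description in the cuBLASLt docs linked in the issue, so treat it as an assumption to verify before handing such a buffer to a real cublasLt kernel. `to_col32`/`from_col32` are hypothetical helpers, and no cublasLt call is made.

```python
import torch


def to_col32(a: torch.Tensor) -> torch.Tensor:
    """Reorder a row-major (R, C) matrix into COL32 storage.

    Assumed layout: element (r, c) lives at offset (c // 32) * (32 * R) + r * 32 + c % 32,
    i.e. tiles of 32 columns, each tile storing its R rows of 32 contiguous elements.
    Returns the logical (R, C // 32, 32) view whose last dim is the 32-wide vector.
    """
    R, C = a.shape
    assert C % 32 == 0, "pad C to a multiple of 32 first"
    storage = a.reshape(R, C // 32, 32).permute(1, 0, 2).contiguous()  # (C//32, R, 32)
    return storage.permute(1, 0, 2)  # logical (R, C//32, 32), strides (32, 32 * R, 1)


def from_col32(v: torch.Tensor) -> torch.Tensor:
    """Back to a plain row-major (R, C) matrix (reshape copies here)."""
    R, T, _ = v.shape
    return v.reshape(R, T * 32)
```

The layout is then fully described by shape and strides; what is lost is the nominal distinction, which is the point made above.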
> The whole of FasterTransformer was built to preserve col32 and call col32 kernels, I think :)
Yes. But for anything fancy you need a col32 kernel to be able to work with col32 format, and for pointwise literally anything works lol.
> My main problem with tensor subclasses is that they make it appear that they compose well with other subclasses and ops, while in fact this is most often not the case. For subclass wrappers over dense tensors, it would help to have at least some auto-rewrap peephole dispatch for elementwise ops, or some way to enable auto-casting to a dense tensor if an op is not supported.
Yeah, the compositionality question is troublesome. Wrapper subclasses tend to compose better, but not all of our subclasses are implemented this way. That being said, auto-casting to dense is something that you CAN do, relatively easily, with a fallback torch dispatch. We even have multiple dispatch support via NotImplemented, so two subclasses that don't know about each other can still interoperate, if one of them is willing to desugar to dense.
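A minimal sketch of such a fallback, assuming the standard wrapper-subclass pattern (`torch.Tensor._make_wrapper_subclass`, disabling `__torch_function__`, and the private `torch.utils._pytree.tree_map` helper): every aten op is desugared to the dense inner tensor and returns plain tensors. `DenseFallbackTensor` is a hypothetical name, and the NotImplemented-based multiple dispatch mentioned above is not shown.

```python
import torch
from torch.utils._pytree import tree_map


class DenseFallbackTensor(torch.Tensor):
    """Hypothetical wrapper subclass that desugars every op to its dense inner tensor."""

    # Skip default __torch_function__ handling, which would try to re-wrap outputs.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem: torch.Tensor):
        # The wrapper carries the inner tensor's metadata but has no storage of its own.
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, strides=elem.stride(), dtype=elem.dtype,
            device=elem.device, requires_grad=elem.requires_grad,
        )

    def __init__(self, elem: torch.Tensor):
        self.elem = elem

    def __repr__(self):
        return f"DenseFallbackTensor({self.elem!r})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Fallback for every aten op: unwrap to dense, run the op, return plain tensors.
        def unwrap(x):
            return x.elem if isinstance(x, DenseFallbackTensor) else x

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
```

A real subclass would handle the ops it understands first and only fall through to this dense path for the rest.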
> I also wonder if a weaker typing instrument, like attaching a tag dict to dense tensor objects, could support some dispatch mechanisms.
I actually do kind of think we should not have let people subclass Tensor, and instead have some non-OO mechanism that is able to let you inplace change the "class" of a tensor at runtime. The ship has really sailed on this one.
> shove the vector count as the last stride and you're all done.
It did not occur to me :) In that case, I guess it would be great to have some examples of calling cublasLt with this int8-quantized + col32 format (with all data prep/transposes done in PyTorch). And maybe oneDNN's nChw16c/nChw8c? https://www.speechmatics.com/company/articles-and-news/fast-and-accurate-gpu-quantization-for-transformers said that col32 gave a perf boost for quantized inputs; maybe that's no longer the case in CUDA 12?
Also, I wonder if some more tuple types would be useful for TensorIterator or internal structures, e.g. there already exist Vec256 and Vec512, right? I initially thought that introducing these tuple dtypes could help pass the user's intent to the dispatcher that these col32/onednn kernels should be called.
Is it right to consider these tupled memory formats as a kind of midpoint between nchw and nhwc?
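One way to check the "midpoint" intuition, as a sketch: a oneDNN-style nChw{block}c layout can be built as a contiguous (N, C // block, H, W, block) tensor, following my reading of the oneDNN memory-format guide linked in the issue. With block = 1 the memory order is exactly NCHW, and with block = C it is exactly NHWC (channels-last). `to_blocked` is a hypothetical helper.

```python
import torch


def to_blocked(x: torch.Tensor, block: int) -> torch.Tensor:
    """NCHW -> nChw{block}c, as a contiguous (N, C // block, H, W, block) tensor.

    block=8 corresponds to 8 float32 = 256 bits (an AVX2 register), block=16 to AVX512.
    """
    N, C, H, W = x.shape
    assert C % block == 0
    return x.reshape(N, C // block, block, H, W).permute(0, 1, 3, 4, 2).contiguous()
```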
> non-OO mechanism that is able to let you inplace change the "class" of a tensor at runtime. The ship has really sailed on this one.
I think, for things that differ little from dense tensors, standardizing on adding some tag dict is still worth it (e.g. for all the Boxes/Polygons in torchvision). This is also helpful for modernizing existing code without changing the signatures of functions working with a torch.Tensor, and for preserving users' existing understanding of what's going on.
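As a sketch of what already works today without any subclass: a plain Python attribute can be attached to an ordinary dense tensor. The `tags` name and its contents are a hypothetical convention. The snippet also shows the missing piece a standardized mechanism would have to define: ops and views return new tensor objects, so the tags are not propagated.

```python
import torch

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
# Plain attribute on an ordinary dense tensor: no subclass, type and signatures unchanged.
boxes.tags = {"kind": "boxes", "format": "xyxy"}  # hypothetical tag convention

moved = boxes + 1.0   # a new tensor object: tags are not carried over
first = boxes[0]      # a view is also a new Python object without tags
```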
For things really different from dense tensors, maybe subclasses are okay (as a means for dispatch), but then there should be a doc page per subclass listing all available ops, to remove the illusion that all ops are supported, and perhaps more call-throughs should be added. (E.g. I think TensorList is a good candidate for a tensor-like subclass, as it's mainly a means for dispatch, and it would be nice to call torch.div directly on a TensorList instead of torch._foreach_div.)
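A minimal sketch of the TensorList idea, assuming only the existing private `torch._foreach_div` kernel: a hypothetical `TensorList` wrapper routes the `/` operator to it. Making `torch.div(tensor_list, 2)` itself work would additionally need a `__torch_function__` hook, which this sketch does not attempt.

```python
import torch
from typing import List


class TensorList:
    """Hypothetical thin wrapper that routes arithmetic to the existing _foreach_* kernels."""

    def __init__(self, tensors: List[torch.Tensor]):
        self.tensors = list(tensors)

    def __truediv__(self, other):
        if isinstance(other, TensorList):
            # Elementwise division list-by-list.
            return TensorList(torch._foreach_div(self.tensors, other.tensors))
        # Division of every tensor by the same scalar.
        return TensorList(torch._foreach_div(self.tensors, other))
```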
@ezyang Another reason for wide dtypes: some things (like higher-order associative scan) are used for "pointwise" processing. If we want to process complex numbers, quaternions or mat4x4, it might be nice to have such a crutch for this (e.g. like the vec4/vec3/matNxM/matN types in GLSL for graphics).
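For comparison, a sketch of the "trailing dimension" alternative for quaternions: each quaternion is a (..., 4) slice (w, x, y, z), and the Hamilton product broadcasts like a pointwise op. `quat_mul` could serve as the combine function of an associative scan; here only a plain sequential reference scan (`quat_cumprod`, hypothetical) is shown, not any higher-order scan API.

```python
import torch


def quat_mul(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """Hamilton product of quaternions stored as (..., 4) = (w, x, y, z); broadcasts."""
    w1, x1, y1, z1 = p.unbind(-1)
    w2, x2, y2, z2 = q.unbind(-1)
    return torch.stack(
        [
            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        ],
        dim=-1,
    )


def quat_cumprod(q: torch.Tensor) -> torch.Tensor:
    """Sequential reference scan over a (T, 4) sequence: out[t] = q[0] * q[1] * ... * q[t]."""
    out = [q[0]]
    for t in range(1, q.shape[0]):
        out.append(quat_mul(out[-1], q[t]))
    return torch.stack(out)
```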
FWIW our current opinion is that a wide dtype for complex was a mistake; instead, we should have had complex as a subclass composing some underlying float tensor. This gets you all the float variants of complex, and it lets you choose whether to interleave or not (the uninterleaved representation is required for efficient matmul).
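To illustrate the uninterleaved representation, a sketch of complex matmul on separate real and imaginary planes using only real matmuls, from (Ar + iAi)(Br + iBi) = (ArBr - AiBi) + i(ArBi + AiBr). `complex_matmul_split` is a hypothetical helper; `torch.view_as_real` gives the interleaved (..., 2) form, and `.contiguous()` on each plane turns it into the planar layout.

```python
import torch


def complex_matmul_split(ar, ai, br, bi):
    """(ar + i*ai) @ (br + i*bi) using only real matmuls on uninterleaved planes."""
    real = ar @ br - ai @ bi
    imag = ar @ bi + ai @ br
    return real, imag


def to_planes(z: torch.Tensor):
    """Interleaved complex tensor -> two contiguous real planes (real, imag)."""
    zr, zi = torch.view_as_real(z).unbind(-1)  # strided views into the interleaved data
    return zr.contiguous(), zi.contiguous()
```

Because the planes are ordinary float tensors, the same function works unchanged for float dtypes such as bfloat16, which has no built-in complex counterpart in PyTorch.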
Maybe... I would say that these tiled dtypes are useful for representational clarity/UX and for some dispatch. But overall, I'd say there is value in being able to faithfully express/represent as many (useful/frequent in practice) layouts as possible, even if ops are supported only for some of them.
IMO the best would be supporting tagged typing (to avoid subclasses, wrap/unwrap clutter and type hierarchy explosion) and enabling tiled dtypes as a high-level construct for representing actual semi-channels-last layouts for vectorized reads/writes, and maybe for dispatch transparency. And tagged typing info could, e.g., support more sparse layouts or uninterleaved representations.
This is also useful because when backends start doing layout tuning, it's important to at least report and represent the layouts being tried, so that the user can insist on a specific layout.
Regarding the complex wide dtype, I'd say it's still important if it's used as a legacy layout in other existing contexts.
Maybe for the complex use case specifically, having the wide dtype as the main/single option was a mistake, but I think being able to logically represent various wide dtypes can be useful. E.g. if we have a bool8 dtype, it's trivially packable into torch.bits :) Such wide types are quite common in shaders/computer graphics for representing xyz coordinates, color tuples, etc.
Ha-ha, in the new post on TK (ThunderKittens), Hazy Research advocates wider use of wide dtypes as a logical primitive: https://hazyresearch.stanford.edu/blog/2024-05-12-tk
But ThunderKittens has good abstractions -- small tiles -- that match where both AI and hardware are going. ThunderKittens doesn’t support any dimension less than 16. But in our view, this doesn’t really matter, since the hardware doesn’t particularly want to, either. And we ask: if your matrix multiply is smaller than 16x16, are you sure what you’re doing is AI? From a philosophical point of view, we think a frame shift is in order. A “register” certainly shouldn’t be a 32-bit word like on the CPUs of old. And a 1024-bit wide vector register, as CUDA uses, is certainly a step in the right direction. But to us a “register” is a 16x16 tile of data. We think AI wants this -- after all this time, it’s still just matrix multiplies, reductions, and reshapes. And we think the hardware wants this, too -- small matrix multiplies are just begging for hardware support beyond just the systolic mma.
🚀 The feature, motivation and pitch
I found a blog post explaining how to get speedups by using int8 gemm kernels on CUDA: https://www.speechmatics.com/company/articles-and-news/fast-and-accurate-gpu-quantization-for-transformers
It mentions several specialized memory layouts to maximize perf of cublas gemm ops (found them documented in https://docs.nvidia.com/cuda/cublas/#cublasltorder-t):
- CUBLASLT_ORDER_COL32 (mentioned as the most performant memory layout for cublasLt int8 gemms)
- CUBLASLT_ORDER_COL4_4R2_8C
- CUBLASLT_ORDER_COL32_2R_4R4
These memory formats/layouts seem important for max-perf cublasLt int8, as evidenced by this blog post and the FasterTransformer kernels. @vkuzo are these formats supported by _int_mm in https://github.com/pytorch/pytorch/pull/96685?
Is CUBLASLT_ORDER_COL32 logically representable in PyTorch? Other formats?
For int8, I understand this as if there were a dtype holding 32-sized tiles/tuples of torch.int8 (32 bytes = 256 bits), with these tuples then stored in column-major order. Currently the widest dtype is complex128 (which holds a tuple of two float64, i.e. 16 bytes = 128 bits).
So IMO a minimal way to support it would be:
- torch.int8x32, maybe torch.float16x32 and so forth
- torch.bits2x4 (there also exists torch.quint2x4) and torch.bits16 already exist, so maybe at least torch.bits2048 could be added?
- memory_format= supporting conversions to these layouts via this avenue (especially for CUBLASLT_ORDER_COL32_2R_4R4?)
In that blog post they fuse the quantization + conversion into a LayerNorm kernel, but I guess this can be introduced later (if needed at all).
Also, FasterTransformer has many COL32 kernels: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/unfused_attention_int8_kernels.cu
It also appears that similar tiled/blocked memory formats are used for MKLDNN: https://oneapi-src.github.io/oneDNN/v1.0/dev_guide_understanding_memory_formats.html, and they are probably already supported by calling .to_mkldnn() / weight prepacking. I wonder if supporting such tupled dtypes could be good and unify the layouts for cublasLt and MKLDNN (cc @jamesr66a @mingfeima). For onednn it is:
@jerryzh168 @vkuzo @ngimel this is probably also related to int8 gemm code paths as we discussed in https://github.com/pytorch/pytorch/issues/69364
(The 256-bit AVX2 equivalent is float32x8, and for AVX512 it's float32x16.)