Open CHDev93 opened 1 year ago
See functions like div_rn, div_rz, etc.
Thanks! I have never come across this libdevice module before and it looks really useful.
For my particular use case I'm trying to do integer division on arrays (trying to compute a mask based on the row and column indices). Looks like ddiv_rd is almost what I want except it expects the inputs to be doubles.
raise ValueError(f"input arg type does not match."
ValueError: input arg type does not match.Expect one of dict_keys([(triton.language.fp64, triton.language.fp64)]), got (triton.language.int32, triton.language.int32)
My workaround for now is to do something like:
def div_rd(a, b):
    return (a - (b - 1)) // b
I do think having // behave differently than numpy or python is a bit surprising in the first place. Negative indices come up pretty often for me when creating masks for my tilings.
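For reference, here is a minimal pure-Python sketch of how floor division can be recovered from a division that rounds toward zero. The helpers trunc_div, floor_div and workaround are hypothetical names used only for illustration and are not Triton APIs; trunc_div merely emulates the round-toward-zero behaviour described in this thread. Under that assumption, the (a - (b - 1)) // b workaround appears to agree with floor division only when a is non-positive and b is positive, which covers the negative-index case but not a general one.

```python
def trunc_div(a: int, b: int) -> int:
    """Emulate integer division that rounds toward zero (hypothetical helper)."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q


def floor_div(a: int, b: int) -> int:
    """Floor division built only from truncating division (hypothetical helper)."""
    q = trunc_div(a, b)
    r = a - q * b  # remainder left over by the truncated quotient
    # Truncation lands one above the floor when the remainder is
    # non-zero and its sign differs from the divisor's sign.
    if r != 0 and ((r < 0) != (b < 0)):
        q -= 1
    return q


def workaround(a: int, b: int) -> int:
    """The workaround from this thread, evaluated with truncating division."""
    return trunc_div(a - (b - 1), b)


if __name__ == "__main__":
    print(trunc_div(-7, 2), floor_div(-7, 2), workaround(-7, 2))
    print(trunc_div(4, 2), floor_div(4, 2), workaround(4, 2))
```

If both signs can occur in the numerator, a remainder-based correction like floor_div is the safer pattern; the shorter workaround is fine when the numerator is known to be zero or negative.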
You can convert a and b to fp64 or fp32 first before applying the libdevice function.
Please let me be more specific. I shouldn't assume that something not in the doc is obvious.
Likewise, you can use a.to(tl.float32) to convert the data type.
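One caveat with the casting route, sketched here in plain Python rather than Triton: float32 has a 24-bit significand, so it cannot represent every int32 exactly, whereas float64 can. The helpers to_f32, floor_div_f64 and floor_div_f32 are hypothetical names for this illustration only, and the float32 rounding is emulated with the standard struct module rather than run on a GPU.

```python
import math
import struct


def to_f32(x: float) -> float:
    """Round a Python float to the nearest float32 value (hypothetical helper)."""
    return struct.unpack("f", struct.pack("f", x))[0]


def floor_div_f64(a: int, b: int) -> int:
    """Floor division via a double-precision quotient."""
    return math.floor(float(a) / float(b))


def floor_div_f32(a: int, b: int) -> int:
    """Floor division via a single-precision quotient (emulated with struct)."""
    quotient = to_f32(to_f32(float(a)) / to_f32(float(b)))
    return math.floor(quotient)


if __name__ == "__main__":
    big = 2**24 + 1  # first positive integer that float32 cannot represent
    print(floor_div_f64(-7, 2), floor_div_f32(-7, 2))
    print(floor_div_f64(big, 1) == big, floor_div_f32(big, 1) == big)
```

For small row and column indices either precision should behave the same; the difference only matters once values exceed 2**24 in magnitude.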
Thanks for the code snippet! Decided to go with (a - (b - 1)) // b to avoid all the various casting. Would be nice if triton // behaved like numpy but I appreciate it's a pretty severe breaking change. Maybe a warning like in torch.div could be useful?
> Would be nice if triton // behaved like numpy but I appreciate it's a pretty severe breaking change
Not sure if we will go this direction, but will let you know once we've got a solution. We can keep this thread open.
❓ Question
Both numpy and python do integer division rounding towards minus infinity (floor division).
Torch rounds toward 0 (like triton) but prints an explicit warning that this behaviour is deprecated. I think the current behaviour is rather unexpected. Maybe there should be a function (like torch.div) that allows the user to explicitly decide on the rounding behaviour? In my current use case I rely heavily on floor division, even for negative numbers.
Versions
triton==2.0.0.dev20221025
torch==1.12.1+cu116
Reproducer