triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

`//` operator rounds towards zero #955

Open CHDev93 opened 1 year ago

CHDev93 commented 1 year ago

❓ Question

Both numpy and python do integer division rounding towards minus infinity (floor division).

>>> (-1) // 2
-1
>>> import numpy as np
>>> np.arange(-5, -1) // 2
array([-3, -2, -2, -1])

Torch rounds toward zero (like Triton) but prints an explicit warning that this behaviour is deprecated. I think the current behaviour is rather unexpected. Maybe there should be a function (like torch.div) that lets the user explicitly choose the rounding behaviour? In my current use case I rely heavily on floor division, even for negative numbers.
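
For reference, torch.div exposes the choice through its rounding_mode argument; a minimal sketch of the two behaviours:

import torch

a, b = torch.tensor(-1), torch.tensor(2)
print(torch.div(a, b, rounding_mode='floor'))  # tensor(-1), matches python's //
print(torch.div(a, b, rounding_mode='trunc'))  # tensor(0), rounds toward zero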

Versions

Reproducer

import torch
import triton
import triton.language as tl

@triton.jit
def _integer_div(
    output_ptr,
    n_rows,
    n_cols,
    w,
    r,
    stride_xn,
    stride_xm,
    BLOCK_SIZE: tl.constexpr,
):
    # compute memory offsets of elements handled by this instance
    row_id = tl.program_id(axis=0)
    col_id = tl.program_id(axis=1)

    row_offsets = row_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    col_offsets = col_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    x_offsets = row_offsets[:, None] * stride_xn + col_offsets[None, :] * stride_xm

    # mask out-of-bounds elements (nothing is loaded from x; this kernel only stores)
    mask = (row_offsets[:, None] < n_rows) & (col_offsets[None, :] < n_cols)

    # write-back
    tl.store(output_ptr + x_offsets, (-w) // r, mask=mask)

def integer_div(x: torch.Tensor) -> torch.Tensor:
    BLOCK_SIZE = 16
    output = torch.empty_like(x)
    assert x.is_contiguous()

    n_rows, n_cols = x.shape
    grid_x = triton.cdiv(n_rows, BLOCK_SIZE)
    grid_y = triton.cdiv(n_cols, BLOCK_SIZE)
    grid = (grid_x, grid_y)

    r = 5
    w = 8

    expected_answer = (-w) // r
    print(f"{expected_answer=}")

    _integer_div[grid](
        output,
        n_rows,
        n_cols,
        w,
        r,
        x.stride(0),
        x.stride(1),
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return output

def main() -> None:
    x = torch.rand(size=(16, 16)).cuda()
    out = integer_div(x)
    triton_answer = out[0, 0].item()
    print(f"{triton_answer=}")

main()
# Program output
# expected_answer=-2
# triton_answer=-1.0
Jokeren commented 1 year ago

Use libdevice https://github.com/openai/triton/blob/master/python/triton/language/libdevice.py

Jokeren commented 1 year ago

See functions like div_rn, div_rz, etc.

CHDev93 commented 1 year ago

Thanks! I have never come across this libdevice module before and it looks really useful.

For my particular use case I'm doing integer division on arrays (computing a mask based on the row and column indices). Looks like ddiv_rd is almost what I want, except it expects the inputs to be doubles.

    raise ValueError(f"input arg type does not match."
ValueError: input arg type does not match.Expect one of dict_keys([(triton.language.fp64, triton.language.fp64)]), got (triton.language.int32, triton.language.int32)

My workaround for now is to do something like

def div_rd(a, b):
  return (a - (b - 1)) // b 

I do think having // behave differently from numpy or python is a bit surprising in the first place. Negative indices come up pretty often for me when creating masks for my tilings.
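
For reference, the trick above floors correctly when the dividend is non-positive and the divisor positive, but for positive dividends it usually comes out one too small. A more general sketch built only from Triton's truncating // and * (the floor_div helper below is hypothetical, not built into Triton, and assumes int32 operands with b != 0):

import triton
import triton.language as tl

@triton.jit
def floor_div(a, b):
    # hypothetical helper: floor division recovered from truncating division
    q = a // b                 # rounds toward zero
    r = a - q * b              # remainder, takes the sign of a
    # step the quotient down by one when the division was inexact and the signs differ
    fix = (r != 0) & ((a < 0) != (b < 0))
    return q - fix.to(tl.int32)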

Jokeren commented 1 year ago

You can convert a and b to fp64 or fp32 before applying the libdevice function.

Jokeren commented 1 year ago

Please let me be more specific. I shouldn't assume that something not in the doc is obvious.

https://github.com/openai/triton/blob/8650b4d1cbc750d659156e2c17a058736614827b/python/test/unit/language/test_core.py#L1098

Likewise, you can use a.to(tl.float32) to convert the data type.
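
Putting the two hints together, a rough sketch of the cast-then-divide route (assuming the wrappers are reachable as tl.libdevice.ddiv_rd and tl.libdevice.floor in this version; check the linked libdevice.py for the exact names in your install):

import triton
import triton.language as tl

@triton.jit
def floor_div_libdevice(a, b):
    # cast the integer operands to fp64 and divide rounding toward minus infinity
    # (ddiv_rd is the fp64 variant named in the error message above)
    q = tl.libdevice.ddiv_rd(a.to(tl.float64), b.to(tl.float64))
    # a plain .to(tl.int32) would truncate toward zero, so take the floor
    # before converting back to an integer
    return tl.libdevice.floor(q).to(tl.int32)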

CHDev93 commented 1 year ago

Thanks for the code snippet! Decided to go with (a - (b - 1)) // b to avoid all the various casting. Would be nice if triton // behaved like numpy but I appreciate it's a pretty severe breaking change. Maybe a warning like in torch.div could be useful?

Jokeren commented 1 year ago

> Would be nice if triton // behaved like numpy but I appreciate it's a pretty severe breaking change

Not sure if we will go in this direction, but we'll let you know once we've got a solution. We can keep this thread open.