triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/

add libdevice.remquo #2176

Open lizelive opened 1 year ago

lizelive commented 1 year ago

Compute a double-precision floating-point remainder in the same way as the remainder() function. Argument quo returns part of quotient upon division of x by y. Value quo has the same sign as x / y and may not be the exact quotient but agrees with the exact quotient in the low order 3 bits.

__nv_remquof __nv_remquo
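
For orientation, the remainder/quotient contract can be sketched in plain Python (a reference sketch of the semantics only, not the libdevice implementation; remquo_ref is a hypothetical name):

import math

def remquo_ref(x: float, y: float):
    # remainder as computed by remainder(): x - n*y, where n is the
    # integral value nearest to x / y
    rem = math.remainder(x, y)
    n = round((x - rem) / y)  # recover the integral quotient
    # quo carries the sign of x / y and the low-order 3 bits of |n|
    quo = int(math.copysign(abs(n) & 0b111, n))
    return rem, quo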

I'm unsure how to add it because none of the other functions seem to have multiple outputs.

@impl.extern
def remquo(arg0, arg1, arg2, _builder=None):
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ],
                              {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("int32"),): ("__nv_remquof", core.dtype("fp32")),
                               (core.dtype("fp64"), core.dtype("fp64"), core.dtype("int64"),): ("__nv_remquo", core.dtype("fp64")),
                               }, _builder)

results in

ValueError: input arg type does not match.Expect one of dict_keys([(triton.language.fp32, triton.language.fp32, triton.language.int32), (triton.language.fp64, triton.language.fp64, triton.language.int64)]), got (triton.language.fp32, triton.language.fp32, triton.language.fp32)
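
The wrapper dispatches on the exact tuple of argument dtypes, so the third argument has to arrive as int32 (or int64 for the fp64 overload) before the call ever reaches libdevice; a minimal sketch of the lookup that produces this error, with the table simplified to strings:

# simplified sketch of the overload lookup inside extern.elementwise
overloads = {
    ("fp32", "fp32", "int32"): ("__nv_remquof", "fp32"),
    ("fp64", "fp64", "int64"): ("__nv_remquo", "fp64"),
}
arg_types = ("fp32", "fp32", "fp32")  # the dtypes reported in the traceback
if arg_types not in overloads:
    raise ValueError(f"input arg type does not match. "
                     f"Expect one of {overloads.keys()}, got {arg_types}")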
Jokeren commented 1 year ago

Could you copy and paste your full script?

lizelive commented 1 year ago
import torch
import timeit

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)
# torch.set_default_dtype(torch.float32)

DIST_TO_BLACKHOLE = 2.52e20
GALAXY_RADIUS = 4.7e20
STARS_PER_CUBIC_METER = 9.7e-50

# need to be no closer than 100m apart for int64
# f32 works up to 10km
# i think 1024m per tile is probably good
# the Ever Given is 400m long
# 20,000m is a really big sci-fi ship
# the Death Star is 200,000m (definitely not a moon)
# conclusion: probably want to use i64 + f64 for position
FINE_FMT = dict(dtype=torch.float32)
CORSE_FMT = dict(dtype=torch.int32)

GRID_SIZE = 1024  # m, about 1km

CHANNELS = 3  # lv, lpf, lpc (linvel, linpos_fine, linpos_corse)
DIMS = 3
COUNT = 2 ** 20

MAX_SPEED = 0.01
NUM_STEPS = 10_000

CORSE_SPAWN_RADIUS = 1024  # grid cells, about 1 Mm
DELTA_TIME = 0.1

def init_storage():
    storage = torch.zeros((COUNT, DIMS, CHANNELS), dtype=torch.float32)

    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # storage = torch.UntypedStorage(size = 3 * DIMS * COUNT, device= device)

    print(storage.stride())

    linpos_fine, linpos_corse, linvel = storage.unbind(-1)
    linpos_corse = linpos_corse.view(**CORSE_FMT)
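    # unbind(-1) above returns three strided views into the same storage;
    # view(dtype=torch.int32) reinterprets the coarse channel's float32 bits
    # as int32 in place, without copying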

    # linpos_corse = torch.zeros_like(linpos_fine)
    # linvel = torch.zeros_like(linvel)
    # linpos_corse = torch.zeros_like(linpos_corse)
    print(linpos_corse.stride())

    return linpos_fine, linpos_corse, linvel

def init_state(linpos_fine, linpos_corse, linvel):
    linpos_fine.set_((GRID_SIZE * torch.rand(COUNT, DIMS)).to(**FINE_FMT))
    # linpos_corse.set_(torch.randint(
    #     -CORSE_SPAWN_RADIUS, CORSE_SPAWN_RADIUS, (COUNT, DIMS), **CORSE_FMT
    # ))
    # linvel.set_((
    #     MAX_SPEED * torch.nn.functional.normalize(torch.rand(COUNT, DIMS) - 1 / 2)
    # ).to(**FINE_FMT))

linpos_fine, linpos_corse, linvel = init_storage()
# start timing
init_state(linpos_fine, linpos_corse, linvel)

def divmod_torch(x, y):
    return x // y, x % y
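
# mirrors Python's divmod() for tensors; defined but unused below, since the
# commented-out line in naive_step_ refers to the builtin divmod name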

# twice as fast by not using integer division and using floor instead
# (see the sketch after this function)
# @torch.compile(fullgraph=True)
def naive_step_(
    linvel: torch.Tensor,
    linpos_fine: torch.Tensor,
    linpos_corse: torch.Tensor,
    delta_time: float,
):
    # linpos_fine.add_(linvel, alpha=delta_time)
    linpos_fine += linvel * delta_time
    # steps, linpos_fine = divmod(linpos_fine, GRID_SIZE)
    # .to(torch.int32)
    steps = linpos_fine.to(torch.int32) // GRID_SIZE
    linpos_corse += steps
    linpos_fine -= steps * GRID_SIZE
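
# the "floor instead of integer division" variant the comment above alludes
# to might look like this (a guess at the intended code, not part of the
# original run):
#     steps = torch.floor(linpos_fine / GRID_SIZE).to(torch.int32)
#     linpos_corse += steps
#     linpos_fine -= steps * GRID_SIZE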

import triton
import triton.language as tl
import triton.language.libdevice as libdevice

@triton.jit
def step_kernel(v_ptr, f_ptr, c_ptr, n_elements, delta_time, BLOCK_SIZE: tl.constexpr):
    """
    c is the coarse position (int32, in grid cells)
    f is the fine position (float32, within a cell)
    v is the velocity
    delta_time is the integration timestep
    """
    # row index
    m = tl.program_id(0)
    # col indices
    block_start = m * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    c_ptr += offsets
    f_ptr += offsets
    v_ptr += offsets

    # mask = None
    c = tl.load(c_ptr)
    f = tl.load(f_ptr)
    v = tl.load(v_ptr)

    f += v * delta_time

    # s = tl.zeros(f.shape, dtype=tl.int32)
    s = f.to(dtype=tl.int32)

    # f = libdevice.remquo(f, GRID_SIZE, s)  # intended call; a remquo-free sketch follows the kernel
    f -= s
    c += s

    tl.store(f_ptr, f)
    tl.store(c_ptr, c)
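
# without remquo, the remainder/quotient pair can be formed with plain ops,
# mirroring naive_step_ (a sketch; assumes GRID_SIZE is visible to the
# kernel, e.g. passed in as a tl.constexpr argument):
#     s = (f / GRID_SIZE).to(tl.int32)
#     f -= s * GRID_SIZE
#     c += s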

def triton_step_(
    linvel: torch.Tensor,
    linpos_fine: torch.Tensor,
    linpos_corse: torch.Tensor,
    delta_time,
):
    n_elements = linvel.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    step_kernel[grid](
        linvel, linpos_fine, linpos_corse, n_elements, delta_time, BLOCK_SIZE=1024
    )

    # BLOCK_SIZE = triton.next_power_of_2(min(n_elements, )

USE_TRITON = True
step_ = triton_step_ if USE_TRITON else naive_step_

with torch.no_grad():
    for i in range(NUM_STEPS):
        step_(linvel, linpos_fine, linpos_corse, DELTA_TIME)

    start = timeit.default_timer()
    print(linpos_corse.dtype)
    for i in range(NUM_STEPS):
        step_(linvel, linpos_fine, linpos_corse, DELTA_TIME)
        # linpos_fine %= GRID_SIZE

    end = timeit.default_timer()
# compute the number of operations per second
elapsed = end - start
ops_per_second = COUNT * NUM_STEPS / elapsed
print(f"{ops_per_second:.2e} ops/s")

remquo was added in my local copy of triton.language.libdevice (the extern snippet at the top of this issue).

Jokeren commented 1 year ago

triton.language.libdevice has been deprecated.
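
The wrappers moved rather than disappeared; assuming a Triton 2.x release where the libdevice functions live under triton.language.math, usage looks roughly like this (a sketch; the module path may differ by version):

import triton
import triton.language as tl

@triton.jit
def sqrt_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offsets)
    tl.store(y_ptr + offsets, tl.math.sqrt(x))  # formerly triton.language.libdevice.sqrt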