lizelive opened 1 year ago
Please copy and paste your full script?
```python
import torch
import timeit

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)
# torch.set_default_dtype(torch.float32)

DIST_TO_BLACKHOLE = 2.52e20
GALAXY_RADIUS = 4.7e20
STARS_PER_CUBIC_METER = 9.7e-50

# need to be no closer than 100 m apart for int64
# f32 works up to 10 km
# i think 1024 m per tile is probably good
# the Ever Given is 400 m long
# 20,000 m is a really big sci-fi ship
# the Death Star is 200,000 m (definitely not a moon)
# conclusion: probably want to use i64 + f64 for position
# (see the precision check after the script)
FINE_FMT = dict(dtype=torch.float32)
CORSE_FMT = dict(dtype=torch.int32)
GRID_SIZE = 1024  # m, about 1 km
CHANNELS = 3  # linvel, linpos_fine, linpos_corse
DIMS = 3
COUNT = 2 ** 20
MAX_SPEED = 0.01
NUM_STEPS = 10_000
CORSE_SPAWN_RADIUS = 1024  # in grid tiles, about 1 Mm
DELTA_TIME = 0.1
def init_storage():
    storage = torch.zeros((COUNT, DIMS, CHANNELS), dtype=torch.float32)
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # storage = torch.UntypedStorage(size=3 * DIMS * COUNT, device=device)
    print(storage.stride())
    linpos_fine, linpos_corse, linvel = storage.unbind(-1)
    # reinterpret the coarse channel's float32 bits as int32 (same itemsize)
    linpos_corse = linpos_corse.view(**CORSE_FMT)
    # linpos_corse = torch.zeros_like(linpos_fine)
    # linvel = torch.zeros_like(linvel)
    # linpos_corse = torch.zeros_like(linpos_corse)
    print(linpos_corse.stride())
    return linpos_fine, linpos_corse, linvel
def init_state(linpos_fine, linpos_corse, linvel):
    # note: set_() rebinds the tensor to fresh storage, so linpos_fine no
    # longer aliases the shared `storage` buffer after this call
    linpos_fine.set_((GRID_SIZE * torch.rand(COUNT, DIMS)).to(**FINE_FMT))
    # linpos_corse.set_(torch.randint(
    #     -CORSE_SPAWN_RADIUS, CORSE_SPAWN_RADIUS, (COUNT, DIMS), **CORSE_FMT
    # ))
    # linvel.set_((
    #     MAX_SPEED * torch.nn.functional.normalize(torch.rand(COUNT, DIMS) - 1 / 2)
    # ).to(**FINE_FMT))
linpos_fine, linpos_corse, linvel = init_storage()
# start timing
init_state(linpos_fine, linpos_corse, linvel)
def divmod_torch(x, y):
    return x // y, x % y
# twice as fast by not using integer division and using floor instead
# (see the floor-based sketch after the script)
# @torch.compile(fullgraph=True)
def naive_step_(
    linvel: torch.Tensor,
    linpos_fine: torch.Tensor,
    linpos_corse: torch.Tensor,
    delta_time: float,
):
    # linpos_fine.add_(linvel, alpha=delta_time)
    linpos_fine += linvel * delta_time
    # steps, linpos_fine = divmod(linpos_fine, GRID_SIZE)
    # carry whole tiles from the fine position into the coarse position
    steps = linpos_fine.to(torch.int32) // GRID_SIZE
    linpos_corse += steps
    linpos_fine -= steps * GRID_SIZE
import triton
import triton.language as tl
import triton.language.libdevice as libdevice
@triton.jit
def step_kernel(v_ptr, f_ptr, c_ptr, n_elements, delta_time, BLOCK_SIZE: tl.constexpr):
    """
    c is the coarse position, f is the fine position, v is the velocity,
    delta_time is the timestep
    """
    # row index
    m = tl.program_id(0)
    # col indices
    block_start = m * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    c_ptr += offsets
    f_ptr += offsets
    v_ptr += offsets
    # no mask: n_elements is assumed to be a multiple of BLOCK_SIZE
    c = tl.load(c_ptr)
    f = tl.load(f_ptr)
    v = tl.load(v_ptr)
    f += v * delta_time
    # s = tl.zeros(f.shape, dtype=tl.int32)
    # note: unlike naive_step_, this carries with an implicit divisor of 1
    # (grid units) rather than dividing by GRID_SIZE
    s = f.to(dtype=tl.int32)
    # f = libdevice.remquo(f, GRID_SIZE, s)
    f -= s
    c += s
    tl.store(f_ptr, f)
    tl.store(c_ptr, c)
def triton_step_(
    linvel: torch.Tensor,
    linpos_fine: torch.Tensor,
    linpos_corse: torch.Tensor,
    delta_time,
):
    # note: the kernel indexes linearly from each base pointer, so it
    # assumes contiguous tensors; the channel views from unbind(-1) are strided
    n_elements = linvel.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    step_kernel[grid](
        linvel, linpos_fine, linpos_corse, n_elements, delta_time, BLOCK_SIZE=1024
    )
    # BLOCK_SIZE = triton.next_power_of_2(min(n_elements, )
USE_TRITON = True
step_ = triton_step_ if USE_TRITON else naive_step_
with torch.no_grad():
    # warmup (also triggers the Triton JIT compile)
    for i in range(NUM_STEPS):
        step_(linvel, linpos_fine, linpos_corse, DELTA_TIME)
    if device.type == "cuda":
        torch.cuda.synchronize()  # drain queued kernels before timing
    start = timeit.default_timer()
    print(linpos_corse.dtype)
    for i in range(NUM_STEPS):
        step_(linvel, linpos_fine, linpos_corse, DELTA_TIME)
        # linpos_fine %= GRID_SIZE
    if device.type == "cuda":
        torch.cuda.synchronize()  # launches are async; wait before stopping the clock
    end = timeit.default_timer()

# compute the number of particle updates per second
elapsed = end - start
ops_per_second = COUNT * NUM_STEPS / elapsed
print(f"{ops_per_second:.2e} ops/s")
```
`remquo` would be added in `triton.language.libdevice` (the underlying CUDA libdevice functions are `__nv_remquof` and `__nv_remquo`), but `triton.language.libdevice` has since been deprecated.
I'm unsure how to add it, because none of the other libdevice functions seem to have multiple outputs.
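
Until a `remquo` binding exists, the pair can be emulated with a divide, a truncating cast, and a multiply-subtract inside the kernel. A sketch of the step kernel above with `GRID_SIZE` passed as a constexpr (the kernel name is illustrative, and like the original it assumes `n_elements` is a multiple of `BLOCK_SIZE` and positions are non-negative):

```python
import triton
import triton.language as tl

@triton.jit
def step_kernel_emulated(
    v_ptr, f_ptr, c_ptr, delta_time,
    GRID_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c = tl.load(c_ptr + offsets)
    f = tl.load(f_ptr + offsets)
    v = tl.load(v_ptr + offsets)
    f += v * delta_time
    # emulated remquo(f, GRID_SIZE): truncating cast gives the quotient,
    # multiply-subtract gives the remainder (correct for non-negative f)
    s = (f / GRID_SIZE).to(tl.int32)
    f -= s * GRID_SIZE
    c += s
    tl.store(f_ptr + offsets, f)
    tl.store(c_ptr + offsets, c)
```

A real `__nv_remquo` binding would fuse the cast and subtract, but it produces two results (the remainder, plus a quotient written through a pointer), which is exactly the multiple-outputs problem noted above.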
results in