Open peterbell10 opened 2 months ago
Just to confirm, the TritonGPU IR is generated from valid Triton python code?
It's came from the lowering from a new operator I'm adding, but I'll see if I can reproduce with an existing operator.
This produces the same error on the current master branch
import triton.language as tl
import triton
import torch
@triton.jit
def test_fn(out_ptr, a_ptr, workspace, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
desc_ptr = workspace
tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=desc_ptr, global_address=a_ptr, load_size=[4, N_BLOCK], global_size=[M, N], element_ty=a_ptr.dtype.element_ty)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_ptr)
gather = tl._experimental_descriptor_load(desc_ptr, [0, 0], [4, N_BLOCK], a_ptr.dtype.element_ty)
tl.store(out_ptr + tl.arange(0, 4)[:, None] * N_BLOCK + tl.arange(0, N_BLOCK)[None, :], gather)
out = torch.empty((4, 128), dtype=torch.float32, device="cuda")
inp = torch.arange(4 * 128, dtype=torch.float32, device="cuda").reshape(4, 128)
workspace = torch.empty(128, dtype=torch.uint8, device="cuda")
test_fn[(1,)](out, inp, workspace, 4, 128, 4, 128)
I'll take a look today
Reopening this as it seems the TMA hardware does support swizzling with only 4 rows of data.
I get this result if it's helpful:
unswizzled:
tensor([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.,
12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.,
24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35.,
36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47.,
48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59.,
60., 61., 62., 63.],
[128., 129., 130., 131., 132., 133., 134., 135., 136., 137., 138., 139.,
140., 141., 142., 143., 144., 145., 146., 147., 148., 149., 150., 151.,
152., 153., 154., 155., 156., 157., 158., 159., 160., 161., 162., 163.,
164., 165., 166., 167., 168., 169., 170., 171., 172., 173., 174., 175.,
176., 177., 178., 179., 180., 181., 182., 183., 184., 185., 186., 187.,
188., 189., 190., 191.],
[256., 257., 258., 259., 260., 261., 262., 263., 264., 265., 266., 267.,
268., 269., 270., 271., 272., 273., 274., 275., 276., 277., 278., 279.,
280., 281., 282., 283., 284., 285., 286., 287., 288., 289., 290., 291.,
292., 293., 294., 295., 296., 297., 298., 299., 300., 301., 302., 303.,
304., 305., 306., 307., 308., 309., 310., 311., 312., 313., 314., 315.,
316., 317., 318., 319.],
[384., 385., 386., 387., 388., 389., 390., 391., 392., 393., 394., 395.,
396., 397., 398., 399., 400., 401., 402., 403., 404., 405., 406., 407.,
408., 409., 410., 411., 412., 413., 414., 415., 416., 417., 418., 419.,
420., 421., 422., 423., 424., 425., 426., 427., 428., 429., 430., 431.,
432., 433., 434., 435., 436., 437., 438., 439., 440., 441., 442., 443.,
444., 445., 446., 447.]], device='cuda:0', dtype=torch.float16)
swizzled:
tensor([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.,
12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.,
24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35.,
36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47.,
48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59.,
60., 61., 62., 63.],
[136., 137., 138., 139., 140., 141., 142., 143., 128., 129., 130., 131.,
132., 133., 134., 135., 152., 153., 154., 155., 156., 157., 158., 159.,
144., 145., 146., 147., 148., 149., 150., 151., 168., 169., 170., 171.,
172., 173., 174., 175., 160., 161., 162., 163., 164., 165., 166., 167.,
184., 185., 186., 187., 188., 189., 190., 191., 176., 177., 178., 179.,
180., 181., 182., 183.],
[272., 273., 274., 275., 276., 277., 278., 279., 280., 281., 282., 283.,
284., 285., 286., 287., 256., 257., 258., 259., 260., 261., 262., 263.,
264., 265., 266., 267., 268., 269., 270., 271., 304., 305., 306., 307.,
308., 309., 310., 311., 312., 313., 314., 315., 316., 317., 318., 319.,
288., 289., 290., 291., 292., 293., 294., 295., 296., 297., 298., 299.,
300., 301., 302., 303.],
[408., 409., 410., 411., 412., 413., 414., 415., 400., 401., 402., 403.,
404., 405., 406., 407., 392., 393., 394., 395., 396., 397., 398., 399.,
384., 385., 386., 387., 388., 389., 390., 391., 440., 441., 442., 443.,
444., 445., 446., 447., 432., 433., 434., 435., 436., 437., 438., 439.,
424., 425., 426., 427., 428., 429., 430., 431., 416., 417., 418., 419.,
420., 421., 422., 423.]], dtype=torch.float16)
I think the problem is on this line int tileRows = 8;
I'll try to address it tomorrow
I am running into an assertion error in the codegen for
local_load
which is coming from the linear layouts code. Here is a minified reproducerWhen lowering to llvm ir it fails with the following error
cc @Jokeren @jlebar