FlagOpen / FlagGems

FlagGems is an operator library for large language models implemented in Triton Language.
Apache License 2.0

[Bugfix] scatter&gather's accuracy test #202

Open GwokHiujin opened 1 month ago

GwokHiujin commented 1 month ago

This PR increases the BLOCK_M sizes in the offset_calculator function to prevent excessively small BLOCK_M values, which may resolve the correctness issues encountered during local testing.


If CI runs smoothly, I can offer a guess at the likely cause of this confusing issue... 🧐

The correctness problem was likely caused by a block size that was too small for the tensors being processed, resulting in incomplete or inaccurate memory accesses and inconsistent results. Increasing BLOCK_M lets each program handle a larger chunk of data, which shrinks the launch grid and reduces the risk of edge-case errors in tensor indexing.
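Illustratively, the change amounts to widening the candidate BLOCK_M values offered to the autotuner. A minimal sketch of that pattern, assuming a Triton autotune setup; the config values and the kernel body below are hypothetical, not the actual FlagGems code:

import triton
import triton.language as tl

# Hypothetical sketch: restrict the autotuner to larger BLOCK_M candidates so
# an excessively small tile can never be selected for the offset calculation.
configs = [
    triton.Config({"BLOCK_M": m}, num_warps=4)
    for m in (512, 1024, 2048)
]

@triton.autotune(configs=configs, key=["n_rows"])
@triton.jit
def offset_calculator_kernel(out_ptr, n_rows, BLOCK_M: tl.constexpr):
    # Placeholder body: each program covers one BLOCK_M-sized chunk of rows.
    pid = tl.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    mask = rows < n_rows
    tl.store(out_ptr + rows, rows, mask=mask)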

GwokHiujin commented 1 month ago

The following test case fails. It may require more robust autotuning.

import torch
import flag_gems

# One billion fp16 elements gathered along dim 0 with an int64 index tensor;
# large enough to stress the offset kernel's launch grid.
x = torch.ones(1000000000, dtype=torch.float16, device='cuda')
idx = torch.arange(1000000000, device='cuda')
y = flag_gems.ops.gather(x, 0, idx)

So it seems we simply cannot let the kernel grid's row count grow too large! However, since we are planning to use codegen for scatter, gather, and related operators, the offset calculation will move into the kernels themselves, which may make this version obsolete. Should we proceed with this patch now?
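For scale, a quick back-of-the-envelope check shows how fast the row count grows when BLOCK_M is small. This assumes one program per BLOCK_M-sized chunk of rows, which is an assumption about the launch scheme rather than the exact FlagGems code:

import math

n_rows = 1_000_000_000  # size of the failing gather above
for block_m in (8, 32, 512, 2048):
    num_programs = math.ceil(n_rows / block_m)
    print(f"BLOCK_M={block_m:>4}: {num_programs:,} programs in the grid")

With BLOCK_M=8 that is 125 million programs along one grid dimension; raising BLOCK_M by a couple of orders of magnitude keeps the grid far smaller.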

tongxin commented 1 month ago

Let's shift to the fused kernel, shall we?
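For reference, a minimal sketch of what a fused 1D gather could look like once the offset arithmetic moves into the kernel itself, so no separate offset_calculator pass is needed. The names gather_1d_kernel and gather_1d are hypothetical illustrations, not the planned FlagGems codegen output:

import torch
import triton
import triton.language as tl

@triton.jit
def gather_1d_kernel(out_ptr, inp_ptr, idx_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load the gather indices and use them to address the input directly,
    # computing offsets inside the kernel instead of in a separate pass.
    idx = tl.load(idx_ptr + offsets, mask=mask, other=0)
    vals = tl.load(inp_ptr + idx, mask=mask, other=0)
    tl.store(out_ptr + offsets, vals, mask=mask)

def gather_1d(inp, idx, block_size=1024):
    # Hypothetical host wrapper: a flat 1D launch grid, one program per block.
    out = torch.empty(idx.shape, dtype=inp.dtype, device=inp.device)
    n = idx.numel()
    grid = (triton.cdiv(n, block_size),)
    gather_1d_kernel[grid](out, inp, idx, n, BLOCK_SIZE=block_size)
    return out

With a 1024-element block, the 10^9-element case above launches roughly a million programs along a single flat grid axis, so the row-count concern from the separate offset pass does not arise in the same form.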