FlagOpen / FlagGems

FlagGems is an operator library for large language models implemented in Triton Language.
Apache License 2.0

add equal_nan=False for assert_equal #225

Closed: iclementine closed this pull request 1 week ago

iclementine commented 3 weeks ago

PR Category

OP Test

Type of Change

Bug Fix

Description

Add an equal_nan parameter (default False) to assert_equal; this is needed to test floating-point arrays for equality. flag_gems.testing.assert_equal is now implemented with torch.testing.assert_close using zero tolerance.

This is useful when we need to test floating-point arrays for equality even though we expect every element to be an integer, as with the results of floor_div, floor, ceil and round. However, these arrays may still contain inf, -inf and nan. Because these values cannot be converted to integers, they are preserved as floating-point values.
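A small illustration of the point above, assuming PyTorch is installed: floor, ceil and round produce integer-valued results that keep the floating-point dtype, non-finite inputs pass through unchanged, and Python's math.floor refuses to turn inf or nan into an int. The helper name rounding_results is hypothetical and exists only for this demo.

```python
import math

import torch


def rounding_results(x: torch.Tensor) -> dict:
    """Hypothetical helper: apply floor, ceil and round to a float tensor."""
    return {
        "floor": torch.floor(x),
        "ceil": torch.ceil(x),
        "round": torch.round(x),  # torch.round rounds half to even
    }


x = torch.tensor([1.5, -2.5, float("inf"), float("-inf"), float("nan")])
results = rounding_results(x)

for name, y in results.items():
    # Integer-valued, but still float32; inf, -inf and nan are preserved.
    print(name, y.dtype, y.tolist())
    # NaN never equals itself, so a plain elementwise == check fails here.
    print("  y == y everywhere:", bool((y == y).all()))

# Converting the non-finite values to a Python int is not possible.
for value in (float("inf"), float("nan")):
    try:
        math.floor(value)
    except (OverflowError, ValueError) as exc:
        print(f"math.floor({value}) raises {type(exc).__name__}")
```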

To test for equality while treating NaNs in matching positions as equal, the equal_nan parameter is added.
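A minimal sketch of an exact-equality helper built on torch.testing.assert_close, in the spirit of the description above. The function below is a hypothetical stand-in, not the actual flag_gems.testing.assert_equal, and its signature is assumed. Passing atol=0 and rtol=0 removes the tolerance, and equal_nan is forwarded unchanged; note that assert_close also checks dtype and device by default.

```python
import torch


def assert_equal(actual: torch.Tensor, expected: torch.Tensor, equal_nan: bool = False) -> None:
    """Hypothetical exact-equality check; the real flag_gems helper may differ."""
    # atol=0 and rtol=0 disable the tolerance: every element must match exactly.
    # inf and -inf are only equal to themselves; NaNs in the same position
    # count as equal only when equal_nan=True.
    torch.testing.assert_close(actual, expected, atol=0, rtol=0, equal_nan=equal_nan)


out = torch.floor(torch.tensor([1.5, float("inf"), float("nan")]))
ref = torch.tensor([1.0, float("inf"), float("nan")])

assert_equal(out, ref, equal_nan=True)  # passes

try:
    assert_equal(out, ref)  # default equal_nan=False: the NaNs are a mismatch
except AssertionError:
    print("NaN mismatch reported")
```

With zero tolerance, even a one-ulp difference such as 1.0 versus 1.000001 in float32 is reported as a failure, which is the intended behavior for integer-valued results.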

Fix floor_divide: follow the CPython/NumPy/PyTorch implementation so that it produces the same results.
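For floating-point inputs, the floor-division routine shared by CPython, NumPy and PyTorch does not compute floor(a / b) directly. It takes mod = fmod(a, b) first, divides a - mod by b, and then corrects the quotient. This matters because rounding in a / b can push the quotient across an integer: 0.5 / 0.1 rounds to 5.0, while 0.5 // 0.1 is 4.0. Below is a minimal elementwise sketch of that routine written with plain PyTorch ops, assuming floating-point tensor inputs; floor_divide_reference is a hypothetical name, and this is not the Triton kernel from the PR.

```python
import torch


def floor_divide_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical elementwise float floor division, CPython/NumPy/PyTorch style."""
    mod = torch.fmod(a, b)  # remainder with the sign of a
    div = (a - mod) / b
    # If the remainder and divisor have opposite signs, step the quotient down.
    adjust = (mod != 0) & ((b < 0) != (mod < 0))
    div = torch.where(adjust, div - 1.0, div)
    # Round div to the nearest integer (it may be slightly off after division).
    floordiv = torch.floor(div)
    floordiv = torch.where(div - floordiv > 0.5, floordiv + 1.0, floordiv)
    # A zero quotient takes the sign of the true quotient a / b.
    floordiv = torch.where(div == 0, torch.copysign(torch.zeros_like(div), a / b), floordiv)
    # Division by zero returns the IEEE result (inf, -inf or nan).
    return torch.where(b == 0, a / b, floordiv)


a = torch.tensor([7.0, -7.0, 7.0, -7.0, 0.5], dtype=torch.float64)
b = torch.tensor([2.0, 2.0, -2.0, -2.0, 0.1], dtype=torch.float64)
print(floor_divide_reference(a, b).tolist())  # [3.0, -4.0, -4.0, 3.0, 4.0]
print(torch.floor(a / b).tolist())            # [3.0, -4.0, -4.0, 3.0, 5.0]
```

One deliberate difference between the reference implementations: CPython raises ZeroDivisionError for x // 0.0, whereas NumPy and PyTorch return inf, -inf or nan. The sketch follows the array libraries here.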

Issue

Progress

Performance