scikit-hep / ragged

Manipulating ragged arrays in an Array API compliant way.
https://data-apis.org/array-api/latest/API_specification
BSD 3-Clause "New" or "Revised" License

feature: ragged.sort and ragged.argsort #48

Closed: jpivarski closed this 6 months ago

jpivarski commented 6 months ago

I thought, "The ragged library is the most straightforward project we have, proceeding from a specification, statically typed everything, and pre-populated docstrings. If the Glorified Autocomplete is good at anything, it should be good at filling in these NotImplementedError("TODO") stubs, right?"

Well... even with a context window full of the already-implemented functions, it wrote

    (impl,) = _unbox(x)
    kind = 'stable' if stable else 'quicksort'
    order = 'F' if descending else 'C'
    if isinstance(impl, ak.Array):
        sorted_indices = ak.argsort(impl, axis=axis, ascending=not descending, stable=stable)
    else:
        sorted_indices = np.argsort(impl, axis=axis, kind=kind, order=order)
    return _box(type(x), sorted_indices)

for argsort and

    (impl,) = _unbox(x)
    if isinstance(impl, ak.Array):
        sorted_array = ak.sort(impl, axis=axis, ascending=not descending, stable=stable)
    else:
        # NumPy does not directly support stable sorting in descending order, so we need to handle it ourselves
        if descending:
            sorted_array = np.sort(impl, axis=axis, kind='stable' if stable else 'quicksort')[::-1]
        else:
            sorted_array = np.sort(impl, axis=axis, kind='stable' if stable else 'quicksort')
    return _box(type(x), sorted_array)

for sort.

What's with the Fortran ('F') and C-contiguous ('C') ordering? That's for memory layout, and none of the other functions have that. The confusion between "order" and "sorting" makes sense in the context of natural language, but they don't have anything to do with each other here. (Actually, NumPy sorting functions have an order, which refers to fields in a structured array—not memory-contiguousness!)
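For reference, here is a minimal sketch of what NumPy's `order` argument means for sorting. The array contents and field names are made up for illustration; the point is only that `order` selects structured-array fields to compare and has nothing to do with 'C' or 'F' memory layout.

```python
import numpy as np

# A structured array with two named fields (illustrative data).
people = np.array(
    [("carol", 35), ("alice", 35), ("bob", 25)],
    dtype=[("name", "U10"), ("age", "i4")],
)

# `order` lists the fields to compare: first by age, then by name to break ties.
by_age_then_name = np.sort(people, order=["age", "name"])

names = by_age_then_name["name"].tolist()
print(names)  # ['bob', 'alice', 'carol']
```

On a plain (non-structured) array there are no fields for `order` to name, so passing 'C' or 'F' there is not a meaningful request.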

I probably should have given it Awkward's argsort and sort docstrings. It guessed the wrong interface.

Maybe it was natural to branch and do both Awkward and NumPy implementations, but in ragged, the alternative of an Awkward Array is a NumPy or CuPy scalar, which doesn't make sense to sort.

Finally, it's true that NumPy doesn't have a descending flag. (ChatGPT must have known that from its training dataset. It's the sort of thing people would warn each other about on StackOverflow.) But it only seemed to know that on the second of two functions.
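To make the missing descending flag concrete, here is a rough sketch of one way to get a descending, stable argsort out of NumPy. The helper name `argsort_descending_stable` is hypothetical and is not part of ragged, Awkward Array, or NumPy. It flips the input, runs a stable ascending argsort, flips the result, and maps the indices back. Simply reversing an ascending argsort also reverses the order of ties. Separately, `[::-1]` as written in the generated sort reverses the first axis, which matches `axis` only for one-dimensional input.

```python
import numpy as np

def argsort_descending_stable(x, axis=-1):
    """Hypothetical helper: descending argsort in which ties keep their original order."""
    n = x.shape[axis]
    flipped = np.flip(x, axis=axis)
    # Stable ascending argsort of the flipped input, expressed in flipped coordinates.
    idx = np.argsort(flipped, axis=axis, kind="stable")
    # Flip the result to get descending order, then map back to original positions.
    return n - 1 - np.flip(idx, axis=axis)

x = np.array([2, 1, 2, 3])

naive = np.argsort(x, kind="stable")[::-1]   # ties come out in reversed order
careful = argsort_descending_stable(x)       # ties keep their original order

print(naive.tolist())    # [3, 2, 0, 1]
print(careful.tolist())  # [3, 0, 2, 1]
```

For sort (as opposed to argsort), tied values are indistinguishable, so flipping an ascending result along the sorted axis with `np.flip(..., axis=axis)` is enough.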

Okay: coding (in the most straightforward instance I can find) is a bust. What about writing tests?

@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("data, expected_indices", [
    ([3, 1, 2], [1, 2, 0]),
    ([[1, 0], [3, 2]], [[1, 0], [1, 0]]),
    ([[5, 7], [], [3, 1]], [[1, 0], [], [1, 0]]),
])
def test_argsort(device, data, expected_indices):
    x = ragged.array(data, device=device)
    result = ragged.argsort(x)
    assert result.tolist() == expected_indices, "Argsort indices mismatch."

@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("data, expected_sorted", [
    ([3, 1, 2], [1, 2, 3]),
    ([[3, 2], [1, 0]], [[2, 3], [0, 1]]),
    ([[5, 7], [], [3, 1]], [[5, 7], [], [1, 3]]),
])
def test_sort(device, data, expected_sorted):
    x = ragged.array(data, device=device)
    result = ragged.sort(x)
    assert result.tolist() == expected_sorted, "Sort order mismatch."

@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("descending", [False, True])
@pytest.mark.parametrize("stable", [False, True])
def test_argsort_sort_options(device, descending, stable):
    data = [2, 3, 1, 4]
    expected_indices_desc = [3, 1, 0, 2] if descending else [2, 0, 1, 3]
    expected_sorted_desc = [4, 3, 2, 1] if descending else [1, 2, 3, 4]

    x = ragged.array(data, device=device)
    result_indices = ragged.argsort(x, descending=descending, stable=stable)
    result_sorted = ragged.sort(x, descending=descending, stable=stable)

    assert result_indices.tolist() == expected_indices_desc, f"Argsort descending={descending}, stable={stable} failed."
    assert result_sorted.tolist() == expected_sorted_desc, f"Sort descending={descending}, stable={stable} failed."

On the first two, the expected_sorted has a different depth of nesting than the data, so it couldn't possibly be right. But maybe the problem is that the inputs and expected outputs are all written in a parametrize decorator. Do people actually do that? I find that it helps to have inputs and expected outputs near each other, clearly lined up, and if they're separate arguments to the parametrize decorator, they can't be lined up.
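As an illustration of the lined-up layout, here is a sketch in which each case keeps its input and expected outputs on one row. It does not call ragged at all: `reference_argsort` is a hypothetical pure-Python stand-in that argsorts each inner list, used only to cross-check the hand-written expectations. In a real test, the ragged functions would be called in its place.

```python
def reference_argsort(lists, descending=False):
    """Hypothetical stand-in: stable argsort of each inner list."""
    out = []
    for inner in lists:
        # sorted() is stable, and reverse=True keeps tied elements in their original order.
        out.append(sorted(range(len(inner)), key=inner.__getitem__, reverse=descending))
    return out

cases = [
    # input                  ascending argsort       descending argsort
    ([[3, 1, 2]],            [[1, 2, 0]],            [[0, 2, 1]]),
    ([[1, 0], [3, 2]],       [[1, 0], [1, 0]],       [[0, 1], [0, 1]]),
    ([[5, 7], [], [3, 1]],   [[0, 1], [], [1, 0]],   [[1, 0], [], [0, 1]]),
]

for data, ascending, descending in cases:
    assert reference_argsort(data) == ascending
    assert reference_argsort(data, descending=True) == descending
```

By this reference, the ascending argsort of [5, 7] is [0, 1], which is the kind of slip that is easier to spot when the columns sit side by side.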

Speaking of which,

assert result.tolist() == expected_indices, "Argsort indices mismatch."

Really? You needed the explanation, instead of letting the assertion stand on its own? I suspect that this sort of thing comes from bad practices in the wild:

# In this comment, I will tell you what is happening in the next line. It is a comment.
# This is the comment I referred to on the previous line. In this comment, I will tell
# you about the code on the next line. It is a print statement. It prints, "hello world".
print("hello world")
# The previous line printed "hello world" to the screen, terminal, Jupyter notebook, or
# other output device where printed lines are printed.

The third test looks right, but why not test the options in all tests? And how about different axis values? (I gave ChatGPT the statically typed, docstringed implementation of the functions it was supposed to test.)
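A sketch of what sweeping the options could look like, using plain NumPy on a rectangular array as a stand-in (an assumption: a real test would call ragged.sort and ragged.argsort and would include ragged inputs). Rather than hard-coding expected values for each combination, it checks a property: gathering the input with the argsort indices should reproduce the sorted values, for every axis and both directions.

```python
from itertools import product

import numpy as np

# Illustrative input with no repeated values, so tie-breaking does not affect the check.
x = np.array([[3, 1, 2], [0, 5, 4]])

checked = 0
for axis, descending in product(range(-x.ndim, x.ndim), [False, True]):
    idx = np.argsort(x, axis=axis, kind="stable")
    expected = np.sort(x, axis=axis, kind="stable")
    if descending:
        # Reverse along the sorted axis (not the first axis).
        idx = np.flip(idx, axis=axis)
        expected = np.flip(expected, axis=axis)
    # Gathering with the argsort indices must reproduce the sorted values.
    assert np.array_equal(np.take_along_axis(x, idx, axis=axis), expected)
    checked += 1

print(checked)  # 8
```

The same loop shape fits pytest by stacking one parametrize decorator per option, as the third generated test already does for descending and stable.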

All in all, it's a bust. But it's funny to be complaining about LLMs not being able to write code for me when I was so blown away by them being able to produce anything meaningful just 10 years ago, in The Unreasonable Effectiveness of Recurrent Neural Networks. I suppose I'm spoiled.