PyO3 / pyo3

Rust bindings for the Python interpreter
https://pyo3.rs
Apache License 2.0

Add locked iterations APIs for dicts and lists #4571

Open ngoldbaum opened 2 months ago

ngoldbaum commented 2 months ago

See https://github.com/PyO3/pyo3/pull/4439 and https://github.com/PyO3/pyo3/pull/4539 for context.

pyo3 follows Python's behavior for multithreaded dict and list iteration and allows race conditions (see my experiment here: https://github.com/PyO3/pyo3/pull/4539#discussion_r1763722120).

Unfortunately, as a consequence, preserving this behavior on the free-threaded build means we will need to use slower owned-reference APIs for lists (#4539) and enter a critical section for dicts on each loop iteration (#4439), since there are no equivalent locking owned-reference iteration APIs for dicts.

It would be nice in both cases if we could simply lock the dict or list for the whole iteration, which would allow us to use the faster APIs that access list and dict internals directly. However, this would make the semantics of iteration via PyO3 differ from iteration via Python.

Instead, I think we should add a new locked_iter function to PyDictMethods and PyListMethods. Maybe PyAnyMethods too, but users can add their own locking for the generic case if they want to; we need to do it for dict and list in PyO3 for performance reasons.

Instead of receiving an iterator directly, users would pass in a closure that accepts an iterator and work with the iterator inside the closure. My understanding from @davidhewitt is that using a closure would ensure exactly one Py_END_CRITICAL_SECTION follows each Py_BEGIN_CRITICAL_SECTION, even if a panic happens and even if the critical sections are recursive.
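A minimal, std-only sketch of why the closure shape helps, under stated assumptions: `begin_critical_section` and `end_critical_section` are hypothetical stand-ins for `Py_BEGIN_CRITICAL_SECTION`/`Py_END_CRITICAL_SECTION` that only bump thread-local counters, and `with_critical_section` here is an illustrative helper, not PyO3's actual implementation. The guard's `Drop` runs both on normal return and while a panic unwinds, and because the guard never leaves the helper, callers cannot leak it or end the section twice.

```rust
use std::cell::Cell;

thread_local! {
    // Hypothetical counters standing in for the real CPython begin/end calls.
    static BEGINS: Cell<usize> = Cell::new(0);
    static ENDS: Cell<usize> = Cell::new(0);
}

fn begin_critical_section() {
    BEGINS.with(|c| c.set(c.get() + 1));
}

fn end_critical_section() {
    ENDS.with(|c| c.set(c.get() + 1));
}

/// Ends the (hypothetical) critical section when dropped, including during unwinding.
struct CriticalSectionGuard;

impl Drop for CriticalSectionGuard {
    fn drop(&mut self) {
        end_critical_section();
    }
}

/// Illustrative closure-scoped helper: the guard is created and dropped here,
/// so exactly one "end" follows each "begin", whether `f` returns or panics.
fn with_critical_section<R>(f: impl FnOnce() -> R) -> R {
    begin_critical_section();
    let _guard = CriticalSectionGuard;
    f()
}

/// Returns (begins, ends) observed so far on this thread.
fn counts() -> (usize, usize) {
    (BEGINS.with(|c| c.get()), ENDS.with(|c| c.get()))
}
```

For example, a closure that panics inside `std::panic::catch_unwind` still leaves the begin and end counts equal.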

I'm not terribly experienced with writing rust APIs that take closures, but I think the API I'm looking for is this?

```rust
fn locked_iter<F>(&self, closure: F) -> PyResult<()>
where
    F: Fn(PyDictLockedIterator<'py>) -> PyResult<()>;
```

Here, PyDictLockedIterator would be like PyDictIterator, but using it would imply that a critical section is held and the dict is locked. A similar API would of course exist for lists.
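A toy sketch of the `locked_iter` shape proposed above, assuming a `Mutex<Vec<i64>>` as a stand-in for a list guarded by a critical section and `Result<R, String>` in place of `PyResult`; `SharedList` and this `locked_iter` are hypothetical. The iterator passed to the closure borrows from the lock guard, so returning it out of the closure is a compile error, and the lock is released as soon as the closure returns.

```rust
use std::sync::Mutex;

/// Toy stand-in for a Python list; the Mutex plays the role of the per-object lock.
struct SharedList {
    items: Mutex<Vec<i64>>,
}

impl SharedList {
    fn new(items: Vec<i64>) -> Self {
        SharedList { items: Mutex::new(items) }
    }

    /// Hypothetical `locked_iter`: the lock is held for the whole closure call.
    /// The `'_` in the bound is higher-ranked, so the iterator cannot escape.
    fn locked_iter<F, R>(&self, closure: F) -> Result<R, String>
    where
        F: FnOnce(std::slice::Iter<'_, i64>) -> Result<R, String>,
    {
        let guard = self.items.lock().map_err(|_| "lock poisoned".to_string())?;
        let result = closure(guard.iter());
        // Release the lock explicitly (it would also be released at scope end).
        drop(guard);
        result
    }
}
```

This sketch uses `FnOnce` rather than `Fn` because the closure runs exactly once, which is the least restrictive bound for callers; it also returns a generic `R` so the closure can produce a value.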

This is something we could add for PyO3 0.24, and for 0.23 we'd merge the two open PRs related to this and only have the slow, safe iteration for the free-threaded build.

Ping @bschoenmaeckers

bschoenmaeckers commented 2 months ago

Thanks for moving this forward. I will pick this up next week so we can merge that MR.

davidhewitt commented 2 months ago

I think the closure needs to be similar to for_each, which takes each element in turn; that way we call PyDict_Next within our own code and can be sure that users can't nest critical sections.

e.g. something like this:

```rust
fn locked_for_each<F>(&self, closure: F) -> PyResult<()>
where
    F: Fn(Bound<'py, PyAny>, Bound<'py, PyAny>) -> PyResult<()>;
```

... that said, now that I say this, I realise that what we should probably do is override Iterator::fold and Iterator::try_fold to enter the critical section once inside them, which would automatically optimize many uses of PyO3 iterators (e.g. .iter().map(...).sum()).
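A toy sketch of the `locked_for_each` variant, where library code drives the loop (the role `PyDict_Next` would play) and the closure only ever receives one key/value pair at a time, so it never holds iteration state of its own. `SharedDict`, its `Vec<(String, i64)>` storage, and `Result<(), String>` in place of `PyResult<()>` are assumptions for illustration; an `Err` from the closure stops the loop early and is propagated.

```rust
use std::sync::Mutex;

/// Toy stand-in for a Python dict: insertion-ordered key/value pairs behind a Mutex.
struct SharedDict {
    entries: Mutex<Vec<(String, i64)>>,
}

impl SharedDict {
    fn new(entries: Vec<(String, i64)>) -> Self {
        SharedDict { entries: Mutex::new(entries) }
    }

    /// Hypothetical `locked_for_each`: lock once, then hand each pair to `closure`.
    fn locked_for_each<F>(&self, mut closure: F) -> Result<(), String>
    where
        F: FnMut(&str, i64) -> Result<(), String>,
    {
        let entries = self.entries.lock().map_err(|_| "lock poisoned".to_string())?;
        // Iteration state stays in library code; the closure only sees single pairs.
        for (key, value) in entries.iter() {
            // `?` stops at the first error; the guard is dropped on return.
            closure(key.as_str(), *value)?;
        }
        Ok(())
    }
}
```

`FnMut` is used instead of `Fn` so the closure can accumulate state across calls, for example summing values into a captured variable.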

ngoldbaum commented 2 months ago

Hmm, can we actually use try_fold? I tried to write an implementation for PyList, and I think that to actually implement it you need the Try trait, which isn't stable yet.

davidhewitt commented 2 months ago

Doh, of course not. We can still implement fold, and maybe on our nightly feature we could implement try_fold.
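A std-only sketch of the `fold` override idea. `SharedList`, `ListIter`, and the lock counter are hypothetical, and a `Mutex` stands in for the critical section. `next` takes the lock once per element, as in the per-iteration locking discussed in this issue, while the overridden `Iterator::fold` takes it once for the whole traversal. Because std adapters such as `map` forward `fold` to the iterator they wrap, a chain like `list.iter().map(|x| x * 2).fold(0, |acc, x| acc + x)` takes the lock once, whereas a plain `for` loop, which calls `next`, still locks per element.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Mutex, MutexGuard};

/// Toy stand-in for a Python list. The Mutex models the per-object critical
/// section; `lock_count` records how often it is acquired.
struct SharedList {
    items: Mutex<Vec<i64>>,
    lock_count: AtomicUsize,
}

impl SharedList {
    fn new(items: Vec<i64>) -> Self {
        SharedList { items: Mutex::new(items), lock_count: AtomicUsize::new(0) }
    }

    fn lock(&self) -> MutexGuard<'_, Vec<i64>> {
        self.lock_count.fetch_add(1, Ordering::Relaxed);
        self.items.lock().expect("lock poisoned")
    }

    fn iter(&self) -> ListIter<'_> {
        ListIter { list: self, index: 0 }
    }

    fn locks_taken(&self) -> usize {
        self.lock_count.load(Ordering::Relaxed)
    }
}

struct ListIter<'a> {
    list: &'a SharedList,
    index: usize,
}

impl Iterator for ListIter<'_> {
    type Item = i64;

    /// Slow path: lock and unlock around every single element.
    fn next(&mut self) -> Option<i64> {
        let items = self.list.lock();
        let item = *items.get(self.index)?;
        self.index += 1;
        Some(item)
    }

    /// Fast path: take the lock once and hold it for the whole traversal.
    /// Note: the toy Mutex is not re-entrant, so `f` must not lock this list again.
    fn fold<B, F>(mut self, init: B, mut f: F) -> B
    where
        F: FnMut(B, Self::Item) -> B,
    {
        let list = self.list;
        let items = list.lock();
        let mut acc = init;
        while let Some(&item) = items.get(self.index) {
            self.index += 1;
            acc = f(acc, item);
        }
        acc
    }
}
```

One caveat of the toy: a std `Mutex` may deadlock or panic if the thread holding it locks it again, so the closure passed to `fold` must not iterate the same list.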

ngoldbaum commented 4 days ago

https://github.com/PyO3/pyo3/pull/4439 did this for dict, but we should do this for list as well.

ngoldbaum commented 22 hours ago

I'm finally getting around to implementing this for list, following the example for dict. Should have a PR ready in the next day or two.