python / cpython

The Python programming language
https://www.python.org

Dataclasses - Improve the performance of `_dataclass_{get,set}state` #103032

Open sobolevn opened 1 year ago

sobolevn commented 1 year ago

Feature or enhancement

I noticed this comment yesterday: https://github.com/python/cpython/blob/1fd603fad20187496619930e5b74aa7690425926/Lib/dataclasses.py#L1128-L1131

Out of curiosity, I decided to try it out and see how fast I could make it. The results are here:

Testing ZeroFields
Dump time (before): 0.229174 sec
Dump time (after) : 0.147407 sec

Testing OneField
Dump time (before): 0.260065 sec
Dump time (after) : 0.197960 sec

Testing EightFields
Dump time (before): 0.331278 sec
Dump time (after) : 0.192171 sec

Testing NestedFields
Dump time (before): 0.846335 sec
Dump time (after) : 0.470739 sec

Here's the simple benchmark that I was using:

import pickle
from dataclasses import dataclass
from timeit import timeit

def perf(e):
    print('Testing', e.__class__.__name__)

    d = pickle.dumps(e)

    iterations = 10000
    t1 = timeit(lambda: pickle.dumps(e), number=iterations)
    t2 = timeit(lambda: pickle.loads(d), number=iterations)
    print(f'Dump time: {t1:.6f} sec')
    print(f'Load time: {t2:.6f} sec')
    print()

@dataclass(frozen=True, slots=True)
class ZeroFields:
    pass

perf(ZeroFields())

@dataclass(frozen=True, slots=True)
class OneField:
    foo: str

perf(OneField('l'))

@dataclass(frozen=True, slots=True)
class EightFields:
    foo: str
    bar: int
    baz: int
    spam: list[int]
    eggs: dict[str, str]
    x: bool
    y: bool
    z: bool

perf(EightFields(
    "a", 1, 2, [1, 2, 3, 4, 5], {'a': 'a', 'b': 'b', 'c': 'c'},
    True, False, True,
))

@dataclass(frozen=True, slots=True)
class NestedFields:
    z1: ZeroFields
    z2: ZeroFields
    e1: EightFields
    e2: EightFields

e = EightFields(
    "a", 1, 2, [1, 2, 3, 4, 5], {'a': 'a', 'b': 'b', 'c': 'c'},
    True, False, True,
)
perf(NestedFields(ZeroFields(), ZeroFields(), e, e))

Here's the very rough version of what I am planning to do:

        cls.__getstate__ = _create_fn(
            '__getstate__',
            ['self'],
            [f"return ({', '.join(f'self.{f.name}' for f in fields(cls))})"],
        )
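To make the nested f-string easier to read, here is a minimal sketch that evaluates the same string-building expression on plain lists of field names. The helper name `getstate_source` is hypothetical and only used for illustration; it does not exist in `dataclasses`. Note that with exactly one field the generated expression is a parenthesized value, not a one-element tuple.

```python
def getstate_source(names):
    # Same expression as in the rough version above, applied to a list of
    # names instead of fields(cls). Hypothetical helper for illustration.
    return f"return ({', '.join(f'self.{n}' for n in names)})"


eight = getstate_source(["foo", "bar", "baz", "spam", "eggs", "x", "y", "z"])
one = getstate_source(["foo"])
zero = getstate_source([])

print(eight)
print(one)   # return (self.foo)  <- parentheses only, not a tuple
print(zero)  # return ()
```
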

Things to do:

  1. Refactor the current example code to be in line with the other code generators
  2. Add similar support for __setstate__

Does it look like a good enough speed up to make this change?

CC @ericvsmith and @carljm

sobolevn commented 1 year ago

Final timings:

Testing ZeroFields
-- before
Dump time: 0.230313 sec
Load time: 0.174264 sec
-- after
Dump time: 0.149132 sec
Load time: 0.105592 sec

Testing OneField
-- before
Dump time: 0.275664 sec
Load time: 0.219554 sec
-- after
Dump time: 0.164988 sec
Load time: 0.123629 sec

Testing EightFields
-- before
Dump time: 0.340355 sec
Load time: 0.355984 sec
-- after
Dump time: 0.181618 sec
Load time: 0.242010 sec

Testing NestedFields
-- before
Dump time: 0.840898 sec
Load time: 0.779367 sec
-- after
Dump time: 0.447484 sec
Load time: 0.453657 sec

But, since we now do more work during dataclass creation, I wanted to measure this effect as well. Here's my small benchmark:

from dataclasses import dataclass
from timeit import timeit

def create_zero():
    @dataclass(frozen=True, slots=True)
    class ZeroFields:
        pass
    return ZeroFields

ZeroFields = create_zero()

def create_one():
    @dataclass(frozen=True, slots=True)
    class OneField:
        foo: str

def create_eight():
    @dataclass(frozen=True, slots=True)
    class EightFields:
        foo: str
        bar: int
        baz: int
        spam: list[int]
        eggs: dict[str, str]
        x: bool
        y: bool
        z: bool
    return EightFields

EightFields = create_eight()

def create_nested():
    @dataclass(frozen=True, slots=True)
    class NestedFields:
        z1: ZeroFields
        z2: ZeroFields
        e1: EightFields
        e2: EightFields

for f in [create_zero, create_one, create_eight, create_nested]:
    print("Testing", f.__name__)
    res = timeit(f, number=100)
    print('Result', res)
    print()

Here are the results:

Testing create_zero
-- before
Result 0.19914910700026667
-- after
Result 0.2348911329972907

Testing create_one
-- before
Result 0.22826643000007607
-- after
Result 0.2645908749982482

Testing create_eight
-- before
Result 0.3498685700033093
-- after
Result 0.4295154959982028

Testing create_nested
-- before
Result 0.41222004499286413
-- after
Result 0.4236956989989267

So, we are basically trading startup time for run time. I am not sure which one is more important here.

One more thing: notice number=10000 in the pickle benchmark, while the second benchmark uses only number=100. So, if we care about the absolute time - I guess making the dataclass creation slower is not worth it.
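The difference in iteration counts can be made concrete with a little arithmetic. The sketch below uses only the EightFields numbers posted in this thread, normalizes them to per-operation cost, and estimates how many pickle operations per class would be needed to pay back the extra creation time. It is a rough estimate from single benchmark runs, not a measurement.

```python
# Timings copied from the EightFields results in this thread.
PICKLE_ITERATIONS = 10_000
CREATE_ITERATIONS = 100

dump_before, dump_after = 0.340355, 0.181618
load_before, load_after = 0.355984, 0.242010
create_before, create_after = 0.3498685700033093, 0.4295154959982028

# Per-operation savings and cost, in seconds.
saved_per_dump = (dump_before - dump_after) / PICKLE_ITERATIONS
saved_per_load = (load_before - load_after) / PICKLE_ITERATIONS
extra_per_class = (create_after - create_before) / CREATE_ITERATIONS

# Number of operations on one class needed to recover the creation overhead.
break_even_dumps = extra_per_class / saved_per_dump
break_even_round_trips = extra_per_class / (saved_per_dump + saved_per_load)

print(f"saved per dump:       {saved_per_dump * 1e6:.1f} us")
print(f"saved per load:       {saved_per_load * 1e6:.1f} us")
print(f"extra per class:      {extra_per_class * 1e6:.1f} us")
print(f"break-even dumps:     {break_even_dumps:.0f}")
print(f"break-even roundtrip: {break_even_round_trips:.0f}")
```

With these numbers, a class like EightFields would need to be dumped roughly 50 times (or round-tripped roughly 29 times) before the faster methods make up for the slower class creation.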

Here's my final patch:

cls.__getstate__, cls.__setstate__ = _dataclass_states(cls)

and:

def _dataclass_states(cls):
    getters = []
    setters = []
    for index, f in enumerate(fields(cls)):
        getters.append(f'self.{f.name}')
        setters.append(f'object.__setattr__(self, "{f.name}", state[{index}])')

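    getstate = _create_fn(
        '__getstate__',
        ('self',),
        # A trailing comma after each getter keeps a one-field state a tuple;
        # "return (self.foo)" would return the bare value instead.
        [f'return ({"".join(g + ", " for g in getters)})'],
    )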
    setstate = _create_fn(
        '__setstate__',
        ('self', 'state'),
        setters if setters else ['pass'],
    )
    return getstate, setstate

Please share your feedback and ideas.

pochmann commented 1 year ago

> I guess making the dataclass creation slower is not worth it.

Maybe an attrgetter prebuilt at data class creation would have both fast creation and fast application? Benchmark for application, using EightFields:

 259 ±  3 ns  hardcoded
 334 ±  4 ns  attrgetter_prebuilt
 412 ±  6 ns  attrgetter_prebuilt_wrapped
2377 ± 42 ns  current
2873 ± 40 ns  attrgetter_on_the_fly

3.10.6 (main, Jan  7 2023, 10:15:17) [GCC 12.2.0]

Note that attrgetter_on_the_fly includes creation time, so is only included to show that the creation of the attrgetter is fast. And attrgetter_prebuilt_wrapped is included because I'm just not sure whether attrgetter_prebuilt can be used directly, like cls.__getstate__ = attrgetter(...).

Code [Attempt This Online!](https://ato.pxeger.com/run?1=lVS9btwwDJ66-CmILLYC9dqkCdAaOKBL8wTdDMOQbfnOqC0ZkpzGF9yTdMnSLp36Cn2MPk2oH58viJd6sEiKIj9-FPXj1zCZvRRPTz9H07z9-O_N70bJHmpmWNUxrbmGth-kMouJQtPyrtaR8zRtz1szO3nN78iBK2akmveYMWrHjeHK72vDTKtNW51S9JwJivaa30fBpCcdRZ9PuRM8eeBi-1WNHD07abSTSeS24Uu725s7By-NAL9GyhQjKqeUTKXQChOUw6LogfUpdAgnQ1PubHy30ynUbWUyDGBxKb_xkEIpZefk6Uw-BDmycGF7DiZxDhfsgsIVhWsKmV8-ULihcJtTeIxZnAL-KMSllUorVVaq4iN1533Vd6zTuDglIpiNN1CNSnFhEs27hvjKFTejEpAh5ZZ5t4Wt2wjWc4K8KGiw_NBLfzD3wfZM1ZWseb0Szlo2SCr1EjJ6kg5BslwG0VIYxIewTmE9-FzLpSikKMyeF003reRd_JLLng3JmR7bimJCX5RCiF-j6CzDoHg5tp3B3vxHPNtNjPcK7xyt-K7YMKyyteI9w2pGgRd_O3eOLqzTdVLoWrRV4wwIk5y6bLN5aIPCG540oazIqwjIjq4F9NikkOXw8ujRFW8nFkcw1Gisd2bgEq74J-dvrL_GsUUyXLisyUmW3ub5OSlN_KjkKOrEjntiNCHpzRH-_oFgdvPv7ddHEBpiV0dhgysmdjy5viXzdL-uz35i7EuuEN7Ve_xOZtt4_0IlHevLmqUQeKDhyNYvBN4Fw3I21LOx1CJKQxZ2Q8kOBIVvfNo6psg54TN3FOaBxGmMi8JetqKws77Sydheu7lDIc6kN_dc6VYK4t_s8HTPT_gz)

```python
from dataclasses import dataclass, fields
from timeit import timeit
from operator import attrgetter
from statistics import mean, stdev
import sys

@dataclass(frozen=True, slots=True)
class EightFields:
    foo: str
    bar: int
    baz: int
    spam: list[int]
    eggs: dict[str, str]
    x: bool
    y: bool
    z: bool

data = EightFields(
    "a", 1, 2, [1, 2, 3, 4, 5], {'a': 'a', 'b': 'b', 'c': 'c'},
    True, False, True,
)

def current(self):
    return [getattr(self, f.name) for f in fields(self)]

def hardcoded(self):
    return self.foo, self.bar, self.baz, self.spam, self.eggs, self.x, self.y, self.z

def attrgetter_on_the_fly(self):
    return attrgetter(*map(attrgetter('name'), fields(self)))(self)

attrgetter_prebuilt = attrgetter(*map(attrgetter('name'), fields(data)))

def attrgetter_prebuilt_wrapped(self):
    return attrgetter_prebuilt(self)

funcs = current, hardcoded, attrgetter_on_the_fly, attrgetter_prebuilt, attrgetter_prebuilt_wrapped

for f in funcs:
    print(f(data))
print()

times = {f: [] for f in funcs}
def stats(f):
    ts = [t * 1e9 for t in sorted(times[f])[:5]]
    return f'{round(mean(ts)):4} ± {round(stdev(ts)):2} ns '
for _ in range(25):
    for f in funcs:
        number = 10000
        t = timeit(lambda: f(data), number=number) / number
        times[f].append(t)
for f in sorted(funcs, key=stats):
    print(stats(f), getattr(f, '__name__', 'attrgetter_prebuilt'))
print()

print(sys.version)
```

sobolevn commented 1 year ago

attrgetter seems like an interesting idea. However, it does not work on ZeroFields:

Traceback (most recent call last):
  File "/Users/sobolev/Desktop/cpython/ex.py", line 18, in <module>
    @dataclass(frozen=True, slots=True)
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sobolev/Desktop/cpython/Lib/dataclasses.py", line 1253, in wrap
    return _process_class(cls, init, repr, eq, order, unsafe_hash,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sobolev/Desktop/cpython/Lib/dataclasses.py", line 1121, in _process_class
    cls = _add_slots(cls, frozen, weakref_slot)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sobolev/Desktop/cpython/Lib/dataclasses.py", line 1226, in _add_slots
    cls.__getstate__ = attrgetter(*map(attrgetter('name'), fields(cls)))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: attrgetter expected 1 argument, got 0

So, we need to evaluate map(attrgetter('name'), fields(cls)) early and convert it to a tuple, and have different code paths for empty types and types with fields.
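The edge cases can be reproduced without dataclasses. The sketch below shows that `attrgetter` with no names raises `TypeError`, and also that with a single name it returns the bare attribute rather than a one-element tuple (the single-name case is not raised in this thread; it is included because a tuple-shaped state would need it handled too). `make_getter` is a hypothetical helper, not a proposed API.

```python
from operator import attrgetter
from types import SimpleNamespace

obj = SimpleNamespace(foo="a", bar=1)

two = attrgetter("foo", "bar")(obj)  # tuple of both values
one = attrgetter("foo")(obj)         # bare value, not a 1-tuple

try:
    attrgetter()                     # no names at all
    empty_error = None
except TypeError as exc:
    empty_error = exc


def make_getter(names):
    """Hypothetical helper: always return a callable that yields a tuple."""
    names = tuple(names)
    if not names:
        return lambda target: ()
    if len(names) == 1:
        single = attrgetter(names[0])
        return lambda target: (single(target),)
    return attrgetter(*names)


print(two, one, type(empty_error).__name__)
print(make_getter([])(obj), make_getter(["foo"])(obj))
```
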

Moreover, a simple cls.__getstate__ = attrgetter(*map(attrgetter('name'), fields(cls))) does not work anyway; a wrapper is required. All in all, this is the final version:

from operator import attrgetter
def __getstate__(self):
    if self.__dataclass_attrgetter__ is None:
        return ()
    return self.__dataclass_attrgetter__(self)

if fields(cls):
    cls.__dataclass_attrgetter__ = attrgetter(*map(attrgetter('name'), fields(cls)))
else:
    cls.__dataclass_attrgetter__ = None

cls.__getstate__ = __getstate__
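One plausible explanation for why the wrapper is needed (an assumption; the thread does not confirm the cause) is that an `attrgetter` object is not a function and does not implement the descriptor protocol, so it is not bound to the instance when looked up through it. A minimal sketch with a hypothetical `Plain` class:

```python
from operator import attrgetter

getter = attrgetter("a", "b")


class Plain:
    def __init__(self):
        self.a = 1
        self.b = 2


# Stored directly: instance lookup returns the attrgetter itself, unbound.
Plain.direct = getter
# Wrapped in a function: normal method binding passes self along.
Plain.wrapped = lambda self: getter(self)

obj = Plain()

try:
    obj.direct()            # called with no arguments at all
    direct_error = None
except TypeError as exc:
    direct_error = exc

wrapped_result = obj.wrapped()
print(type(direct_error).__name__, wrapped_result)
```
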

And here are the timings:

Testing ZeroFields
Dump time: 0.182998 sec

Testing OneField
Dump time: 0.186928 sec

Testing EightFields
Dump time: 0.315626 sec

Testing NestedFields
Dump time: 0.564596 sec

With only one generated method (we still have to do something similar to __setstate__) the creation times are:

Testing create_zero
-- before
Result 0.19914910700026667
-- hardcode
Result 0.2348911329972907
-- getstate attrgetter
Result 0.22353017799468944

Testing create_one
-- before
Result 0.22826643000007607
-- hardcode
Result 0.2645908749982482
-- getstate attrgetter
Result 0.2843174829977215

Testing create_eight
-- before
Result 0.3498685700033093
-- hardcode
Result 0.4295154959982028
-- getstate attrgetter
Result 0.4447074749987223

Testing create_nested
-- before
Result 0.41222004499286413
-- hardcode
Result 0.4236956989989267
-- getstate attrgetter
Result 0.31111940600385424

So, as you can see, the creation time with only one method is slower than with the hardcode patch, and the dump times are also worse.

Maybe there's something we can improve? But in its current state, it is not worth it.

pochmann commented 1 year ago

My first instinct is: "Let's fix attrgetter". Don't you hate it when things refuse to work with empty things...

Second instinct: Don't check every time, use a function instead of None:

from operator import attrgetter
def __getstate__(self):
    return self.__dataclass_attrgetter__(self)

if fields_ := fields(cls):
    cls.__dataclass_attrgetter__ = attrgetter(*map(attrgetter('name'), fields_))
else:
    cls.__dataclass_attrgetter__ = lambda _: ()

cls.__getstate__ = __getstate__

Can you tell why the wrapper is needed?

> Testing create_nested
> -- before
> Result 0.41222004499286413
> -- hardcode
> Result 0.4236956989989267
> -- getstate attrgetter
> Result 0.31111940600385424

That looks odd. Does that mean that creating the attrgetter took negative time?

sobolevn commented 1 year ago

> That looks odd. Does that mean that creating the attrgetter took negative time?

I think I've just messed something up while copy-pasting 🤔 Please, let me double check the results.

pochmann commented 1 year ago

Actually, without the extra attribute (and calling fields(cls) only once):

from operator import attrgetter

if fields_ := fields(cls):
    getter = attrgetter(*map(attrgetter('name'), fields_))
    cls.__getstate__ = lambda self: getter(self)
else:
    cls.__getstate__ = lambda self: ()

carljm commented 1 year ago

Every dataclass must be created; many dataclasses are never pickled at all. So I am not super excited about trading creation time for pickling performance (even though the latter could happen many times per class). I haven't followed all of the numbers or options here (it looks like we don't have reliable numbers for the latest attrgetter option(s) yet?), but IMO for this to be worth it, the impact on dataclass creation time should be very small.