I've spent a lot of time lately working on DataFrame serialization with as little overhead as possible. I'm happy to assist as best I can with development to support this use case and benchmarking to support the decision making. I recently did some rudimentary benchmarking on uncompressed numeric data, see slide 24 in
http://www.slideshare.net/wesm/python-data-wrangling-preparing-for-the-future
the benchmarking code (which I can clean up and publish if so desired) looks like:
from io import BytesIO

import pyarrow as pa
import pyarrow.io as io
import pyarrow.ipc as ipc

def write_arrow(df):
    batch = pa.RecordBatch.from_pandas(df)
    buf = BytesIO()
    writer = ipc.ArrowFileWriter(buf, batch.schema)
    writer.write_record_batch(batch)
    writer.close()
    return buf.getvalue()

def read_arrow(buf):
    reader = ipc.ArrowFileReader(buf)
    return [reader.get_record_batch(i).to_pandas()
            for i in range(reader.num_record_batches)]
Here, on the read path, this is all zero-copy up until converting to pandas.DataFrame, which introduces some overhead. This overhead could be reduced through the slightly messy process of deserializing directly into a 2D ndarray and then constructing pandas's internal BlockManager data structure. This will become much more favorable with pandas 2.0 (it can be fully zero-copy if we use bitmaps for all null handling, see discussion in https://github.com/pandas-dev/pandas2/issues/46).
For compressible data you can obviously add buffer-level compression -- see https://issues.apache.org/jira/browse/ARROW-300. In other words, if you have an array with a null bitmap buffer and a contiguous data buffer, these buffers would be compressed independently when going into the IPC blob.
Dictionary encoding isn't yet implemented, but it's provided for in the IPC metadata.
Maybe this is pie in the sky, but at some point it may be in dask's interests to adopt a language independent serialization protocol so there is the option to run task worker code written in C++, Julia, or other languages.
Hi @wesm. Thanks for the read and write functions. These should be easy to benchmark and easy to incorporate as an optional dependency (hooray for Arrow being on conda-forge, by the way). @jreback any interest in trying these out in the wild? You have some communication-heavy computations.
Zero copy sounds nice. So too does compression on different components. It'll be nice to try some workloads in the wild though to see how much impact these have. That might help us to prioritize or deprioritize them. I have not yet run into workloads that are limited by single memory copies. Network bandwidth is more commonly a larger issue for me, at least on Amazon's hardware.
Maybe this is pie in the sky, but at some point it may be in dask's interests to adopt a language independent serialization protocol so there is the option to run task worker code written in C++, Julia, or other languages.
It's decently possible and there is some slow momentum towards that direction. The custom serialization system is a step in that direction. There are other users clamoring for destination-specific serialization (e.g. gpu-gpu transfers) that would also benefit from this. Presumably you would say "OK, this data of type T has to move from worker A to B, what is the best serialization that we have for this case?" If there were a use case today the serialization work could happen in about a week. I put the cost of making a decent Julia/R/whatever worker/client at "a couple of weeks" by someone familiar with the networking stack in that language. See https://github.com/dask/distributed/issues/586
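To make that idea concrete, here is a tiny, purely hypothetical sketch of the kind of type-keyed serializer registry being described; the names (register_serializer, best_serialization) are invented for illustration and are not the actual dask/distributed API. A destination (e.g. GPU-to-GPU) could become a second key in the same lookup.

serializers = {}

def register_serializer(cls, dumps, loads):
    # Register a (dumps, loads) pair for a given Python type.
    serializers[cls] = (dumps, loads)

def best_serialization(obj):
    # Walk the MRO so that subclasses fall back to a registered parent class.
    for cls in type(obj).__mro__:
        if cls in serializers:
            return serializers[cls]
    raise TypeError("no serializer registered for %r" % type(obj))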
sure, I'll give this a go. IIRC you are going to have this settable via dask.set_options()? (or some completely hacky way is OK), just so I can go back and forth to try things out?
though since I'm on win-64... I will wait for @wesm to kindly make packages available for pyarrow :>
Ah, forgot about the Windows thing. You should wait to try this until I give you the go ahead. I'm fairly comfortable putting Arrow into an optional plugin after #606 gets merged. There will be a few of these for other formats as well.
Initial timings comparing Pickle and Arrow. Corrections or counter-benchmarks welcome. Based on this very narrow benchmark custom serialization + blosc is winning out.
In [1]: import pyarrow as pa
...: import pyarrow.io as io
...: import pyarrow.ipc as ipc
...:
...: def write_arrow(df):
...:     batch = pa.RecordBatch.from_pandas(df)
...:
...:     buf = BytesIO()
...:     writer = ipc.ArrowFileWriter(buf, batch.schema)
...:     writer.write_record_batch(batch)
...:     writer.close()
...:     return buf.getvalue()
...:
...:
...: def read_arrow(buf):
...:     reader = ipc.ArrowFileReader(buf)
...:     return [reader.get_record_batch(i).to_pandas()
...:             for i in range(reader.num_record_batches)]
...:
In [2]: from io import BytesIO
In [3]: import string
In [4]: import numpy as np
In [5]: import pandas as pd
In [6]: n = 1000000
In [7]: df = pd.DataFrame({'x': np.random.randint(0, 10000, size=n, dtype='i4'), 'y': np.random.random(n),
...: 't': np.random.choice(list(string.ascii_lowercase), size=n)})
In [8]: import pickle
In [9]: %time len(write_arrow(df))
CPU times: user 120 ms, sys: 4 ms, total: 124 ms
Wall time: 121 ms
Out[9]: 17000630
In [10]: %time len(pickle.dumps(df))
CPU times: user 36 ms, sys: 8 ms, total: 44 ms
Wall time: 42.7 ms
Out[10]: 14003232
In [11]: %time len(write_arrow(df[['x']]))
CPU times: user 4 ms, sys: 0 ns, total: 4 ms
Wall time: 3.4 ms
Out[11]: 4000334
In [12]: %time len(pickle.dumps(df[['x']]))
CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 2.85 ms
Out[12]: 4000781
In [13]: import blosc
In [14]: x = df.x.values # There is an easier way to do this
In [15]: %time len(blosc.compress(x.data, typesize=x.dtype.itemsize, cname='lz4', clevel=5))
CPU times: user 8 ms, sys: 0 ns, total: 8 ms
Wall time: 3.72 ms
Out[15]: 2020787
In [16]: %time len(pickle.dumps(df[['t']]))
CPU times: user 44 ms, sys: 0 ns, total: 44 ms
Wall time: 45.5 ms
Out[16]: 2002878
In [17]: %time len(write_arrow(df[['t']]))
CPU times: user 96 ms, sys: 0 ns, total: 96 ms
Wall time: 95 ms
Out[17]: 5000406
How did you install pyarrow? I believe you are using a debug build (the conda-forge artifact is a debug build, I can fix this to allow you to do a more apples-to-apples comparison)
conda create -n arrow-test -c conda-forge python=3 pyarrow numpy pandas ipython -y
I may have pip installed some thing in between
(arrow-test) mrocklin@carbon:~$ conda list -e
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
Using Anaconda Cloud api site https://api.anaconda.org
arrow-cpp=0.1.post=1
backports.shutil_get_terminal_size=1.0.0=py35_0
blas=1.1=openblas
ca-certificates=2016.8.31=0
certifi=2016.8.31=py35_0
decorator=4.0.10=py35_0
ipython=5.1.0=py35_1
ipython_genutils=0.1.0=py35_0
ncurses=5.9=9
numpy=1.11.2=py35_blas_openblas_200
openblas=0.2.18=5
openssl=1.0.2h=2
pandas=0.19.0=np111py35_0
parquet-cpp=0.1.pre=3
pexpect=4.2.1=py35_0
pickleshare=0.7.3=py35_0
pip=8.1.2=py35_0
prompt_toolkit=1.0.8=py35_0
ptyprocess=0.5.1=py35_0
pyarrow=0.1.post=0
pygments=2.1.3=py35_1
python=3.5.2=2
python-dateutil=2.5.3=py35_0
pytz=2016.7=py35_0
readline=6.2=0
setuptools=26.1.1=py35_0
simplegeneric=0.8.1=py35_0
six=1.10.0=py35_1
sqlite=3.13.0=1
tk=8.5.19=0
traitlets=4.3.0=py35_0
wcwidth=0.1.7=py35_0
wheel=0.29.0=py35_0
xz=5.2.2=0
zlib=1.2.8=3
libgfortran=3.0.0=1
lz4=0.8.2=py35_0
Yes, you're on a debug build of pyarrow. I'll let you know when there's a new build available with optimizations turned on (this build is gcc -O0).
Digging into a memory leak bug (ARROW-362) but will get a release build up on conda-forge after that's fixed
PyArrow release builds are up. You ideally should disable multithreading in blosc to make the benchmarks comparable -- we could add multithreaded writes to Arrow, too, and get similar speedups
tmp = df[['x']] # this is a copy -- should not be included in benchmark
import blosc
blosc.set_nthreads(1) # multithreading is not apples-to-apples
x = df.x.values
df_t = df[['t']] # this is a copy -- should not be included in benchmark
%timeit len(write_arrow(df))
%timeit len(pickle.dumps(df))
%timeit len(write_arrow(tmp))
%timeit len(pickle.dumps(tmp))
%timeit len(blosc.compress(x.data, typesize=x.dtype.itemsize, cname='lz4', clevel=5))
%time len(pickle.dumps(df_t))
%time len(write_arrow(df_t))
In [27]: %timeit len(write_arrow(df))
...:
10 loops, best of 3: 27.3 ms per loop
In [28]: %timeit len(pickle.dumps(df))
...:
10 loops, best of 3: 24.4 ms per loop
In [29]: %timeit len(pickle.dumps(tmp))
...:
The slowest run took 5.50 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 3: 626 µs per loop
In [30]: %timeit len(write_arrow(tmp))
...:
The slowest run took 6.02 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 3: 689 µs per loop
In [31]: %timeit len(pickle.dumps(tmp))
^[[41;1R
1000 loops, best of 3: 627 µs per loop
In [34]: %timeit len(blosc.compress(x.data, typesize=x.dtype.itemsize, cname='lz4', clevel=5))
...:
The slowest run took 5.77 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 3: 683 µs per loop
In [36]: %timeit len(pickle.dumps(df_t))
10 loops, best of 3: 26.1 ms per loop
In [37]: %timeit len(write_arrow(df_t))
100 loops, best of 3: 17.7 ms per loop
if you change the characteristics of the string data, you get different results:
In [96]: paste
from pandas.util.testing import rands
string_choices = [rands(10) for i in range(1000)]
df = pd.DataFrame({'x': np.random.randint(0, 10000, size=n, dtype='i4'), 'y': np.random.random(n),
't': np.random.choice(string_choices, size=n)})
## -- End pasted text --
In [97]: %timeit len(write_arrow(df))
10 loops, best of 3: 47.1 ms per loop
In [98]: %timeit len(pickle.dumps(df))
10 loops, best of 3: 164 ms per loop
In [99]: df_t = df[['t']] # this is a copy -- should not be included in benchmark
In [100]: %timeit len(pickle.dumps(df_t))
10 loops, best of 3: 152 ms per loop
In [101]: %timeit len(write_arrow(df_t))
10 loops, best of 3: 40.1 ms per loop
there's some performance loss on the Arrow read path right now because we aren't constructing the precise pandas BlockManager. If we wanted to hyperoptimize for the pandas 1.x memory layout this would be a nice project for someone to tackle (multithreaded blockmanager read/write)
Note you can also go through Arrow's native IO subsystem and avoid extra PyBytes interactions -- this lets you interact with memory in C/C++ with zero copy:
def write_arrow(df):
    batch = pa.RecordBatch.from_pandas(df)
    buf = io.InMemoryOutputStream()
    writer = ipc.ArrowFileWriter(buf, batch.schema)
    writer.write_record_batch(batch)
    writer.close()
    return buf.get_result()
YMMV, worth double checking my work in case I made some mistakes
OK, there is a trivial implementation using Arrow on the Dask side here: https://github.com/dask/distributed/pull/643
Once questions there get resolved we should have enough in place to start doing more integrative benchmarks.
In my mind there are two approaches, Blosc-and-custom-code or Arrow. Each has pros and cons.
The Blosc solution can be implemented today and is likely to be near optimal in speed, at least for numeric data. The main cons here are that we need to write code around the block manager and so on, and we need to maintain this going forward. We're likely to miss things like new dtypes, categorical index, etc. The good news here is that we've already done a lot of this a couple of times before and can probably steal ideas from previous implementations in systems like partd.
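As a rough, hedged sketch of what that custom code looks like in the simplest case (plain numeric columns only; datetimes, categoricals, strings, and the index would all need extra handling, which is exactly the maintenance burden described above):

import blosc
import numpy as np
import pandas as pd

def dumps_numeric_frame(df):
    # Header carries the metadata; frames carry one compressed buffer per column.
    header = {'columns': list(df.columns),
              'dtypes': [df[c].dtype.str for c in df.columns],
              'length': len(df)}
    frames = []
    for c in df.columns:
        x = np.ascontiguousarray(df[c].values)
        frames.append(blosc.compress(x.data, typesize=x.dtype.itemsize,
                                     cname='lz4', clevel=5))
    return header, frames

def loads_numeric_frame(header, frames):
    data = {c: np.frombuffer(blosc.decompress(f), dtype=dt)
            for c, dt, f in zip(header['columns'], header['dtypes'], frames)}
    return pd.DataFrame(data, columns=header['columns'])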
The Arrow solution has two main things going for it:
However, in the short term it is missing some things, like per-column or dtype-specific compression and comprehensive support for the Pandas abstraction (see issues in #643; I'd love to be wrong about this).
I would like to write up and tune some performance benchmarks for dask.dataframe sometime in the next month. Serialization and communication costs do contribute significantly in shuffling so this is likely to end up on my critical path soon-ish.
It seems reasonable to have as many options available as possible (assuming the development burden is not too much), and as long as we have reproducible benchmarks to evaluate performance that can help drive performance engineering work. The Arrow<->Pandas conversion has a number of immaturities so it's not a turnkey solution at the moment -- it is an accessible codebase, though, so I invite others to get involved, and I can help review patches and help reach consensus with the other Arrow devs where design changes are needed (e.g. adding compression options).
@mrocklin I think it's worth trying a hybrid approach here.
wholesale pickling of the frame
In [10]: %time len(pickle.dumps(df))
CPU times: user 41.6 ms, sys: 8.45 ms, total: 50 ms
Wall time: 50 ms
Out[10]: 14003232
pickle in a dict-of-columns, about the same
In [7]: %time len(pickle.dumps({c: col for c, col in df.iteritems()}))
CPU times: user 41.9 ms, sys: 14.4 ms, total: 56.3 ms
Wall time: 56.3 ms
Out[7]: 14003220
but now can compress per-column
In [8]: %time len({c: blosc.compress(col.values, typesize=col.dtype.itemsize, cname='lz4', clevel=5) for c, col in df.iteritems()})
CPU times: user 14.4 ms, sys: 3.21 ms, total: 17.6 ms
Wall time: 6.92 ms
Out[8]: 3
In [11]: %time len(pickle.dumps({c: blosc.compress(col.values, typesize=col.dtype.itemsize, cname='lz4', clevel=5) for c, col in df.iteritems()}))
CPU times: user 18 ms, sys: 7.13 ms, total: 25.1 ms
Wall time: 13.7 ms
Out[11]: 12608516
points of note:
reassembling the frame on the receiving side (pd.concat, a single copy) would require that you also send the column order, so a bit more complex
I would like to avoid pickle if possible. It uses a couple of needless memory copies. Also, if we don't compress (which happens sometimes) it's nice to be able to just pass along the memoryview directly. We could still do this when operating on columns (sliced memoryviews are just views).
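For example, to illustrate that last point (plain Python/NumPy, no copies involved):

import numpy as np

x = np.arange(1_000_000, dtype='f8')
mv = memoryview(x)            # view of the array's buffer, no copy
part = mv[:1000]              # slicing the view is also zero-copy
assert part.nbytes == 1000 * x.itemsize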
Why would compressing columns individually be better than compressing several at once? I think that most of the fast compressors are purely local.
If I put up a minimal draft of this would you have time to review/suggest tests?
Why would compressing columns individually be better than compressing several at once? I think that most of the fast compressors are purely local.
This is a big assumption that is only sometimes true. IOW, you assume that 2 integer (or 2 float or whatever) columns are compressible together just because they happen to be the same dtype. If they represent different data then they may or may not compress well as a 2-D array (not that I actually know if blosc is smart about this; if it is, then what I am saying may be false!).
I think it's safer/better to simply compress single columns (this is what a column-major storage system, e.g. Redshift, does).
I would like to avoid pickle if possible. It uses a couple of needless memory copies. Also, if we don't compress (which happens sometimes) it's nice to be able to just pass along the memoryview directly. We could still do this when operating on columns (sliced memoryviews are just views).
yes, I don't think you actually need pickle at all (assuming that you only have strings for object dtypes; FYI, you can do a pd.lib.infer_dtype to check this). A byte protocol is probably best.
If I put up a minimal draft of this would you have time to review/suggest tests?
sure
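Regarding the infer_dtype check suggested above, a minimal sketch using the modern spelling pd.api.types.infer_dtype (the thread-era pd.lib.infer_dtype behaves the same way; the helper name is illustrative):

import pandas as pd

def object_columns_are_strings(df):
    # True only if every object-dtype column holds nothing but strings,
    # in which case a plain byte protocol (no pickle) is enough.
    return all(pd.api.types.infer_dtype(df[c]) == 'string'
               for c in df.columns if df[c].dtype == object)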
This is a big assumption that is only sometimes true. IOW, you assume that 2 integer (or 2 float or whatever) columns are compressible together just because they happen to be the same dtype. If they represent different data then they may or may not compress well as a 2-D array (not that I actually know if blosc is smart about this; if it is, then what I am saying may be false!).
@FrancescAlted can you shed some light?
Question for @jreback: the hard part for me on this problem is creating a function that takes a dataframe and a column and produces a numpy array plus some metadata without performing a copy. The metadata would be for timestamp columns, categorical columns, etc. I suspect that it will take me a while to make a robust version of this function that is aware of all of the Pandas corner cases.
so I'll point you to this: https://github.com/pandas-dev/pandas/blob/master/pandas/io/packers.py#L258 (see convert and unconvert). These are used in msgpack, but there's no reason not to use them (at least for testing) here.
Actually the code that exists in msgpack is pretty robust for serializing/deserializing to bytes (which are then sent to msgpack), but you can almost copy these (again, for testing). Though they work block-by-block (this can easily be changed to do column-by-column instead).
As a bonus, @llllllllll has done work to make the decompression zero-copy (when possible, as it's not always possible).
@jreback Yes, my experience is also that the compression by columns normally brings the best results in large tables. The reason for that is that the shuffle (or bitshuffle) filter normally does a much better job at putting zeros together. When you have complex dtypes, shuffle can still do a good job, but it is slower (i.e. it does not use the SIMD instructions in CPUs), and in addition, Blosc has a limit (mainly for performance reasons) for applying the shuffle filter only for dtypes < 256 bytes.
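A quick way to see the effect being described, comparing Blosc with and without the shuffle filter on the same integer column (compressed sizes will vary with the data):

import blosc
import numpy as np

x = np.random.randint(0, 10000, size=1_000_000).astype('i4')
with_shuffle = blosc.compress(x.data, typesize=4, cname='lz4', clevel=5,
                              shuffle=blosc.SHUFFLE)
without_shuffle = blosc.compress(x.data, typesize=4, cname='lz4', clevel=5,
                                 shuffle=blosc.NOSHUFFLE)
print(len(with_shuffle), len(without_shuffle))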
Regarding serializing/deserializing, have you guys ever tried Google Protocol Buffers? My experience is that they are really fast, especially when used in combination with streamed gRPC. The disadvantage is that both sides need to know the schema in advance, but perhaps this schema can be made flexible enough to transmit buffers of compressed chunks with the actual dtype encoded inside.
@FrancescAlted the protobuf serialize/deserialize steps are fairly expensive. I have had a hard time breaking through the 500 MB/s range on in-memory serialized protobufs. By comparison, I've been able to get 10x or better performance through zero-copy in-memory columnar: http://wesmckinney.com/blog/arrow-streaming-columnar/
Oh wow, that's pretty cool :)
If we add ZSTD/LZ4/Snappy compression steps, we can increase throughput on the wire depending on how fast our network is. I am more than happy to help if these tools are useful
I started looking at Pandas serialization again recently. I'm curious about the near-term expected state of Arrow in two regards: comprehensive coverage of the pandas DataFrame abstraction, and optional per-column compression.
@wesm do you have thoughts on either of these two issues? Is mostly-comprehensive coverage of the pandas dataframe abstraction in scope? What are your thoughts on optional per-column compression? Dask tends to compress a small sample of each frame to determine if it should or should not compress the entire dataset.
Yes, comprehensive pandas coverage is definitely in scope. If folks here can't help with the development, you can surely help with requirements gathering and creating JIRAs. We can attach column metadata to indicate that a particular field is the index (we should generate some unique identifier to give the index "column" a unique name)
On per-column compression: all-or-nothing would be simpler (on a per-batch basis), but if per-column has a lot of benefits, then we could discuss adding the appropriate metadata to support that. Implementing it for Python/C++ isn't especially difficult.
I recently added a small set of tests here that I intend to increase to try to eventually define "comprehensive". We did this for NumPy and it was fairly effective at flushing out issues (or being a repository for new arrays as issues arose). I would be very happy to extend that or find a way to crowdsource it.
For compression there are two topics:
For NumPy communication (which we've spent a bit more time optimizing) I've found that testing the effectiveness of compression on a small sample before compressing the whole has been useful. Short term Dask could also do this itself after Arrow passes data off.
Given that we don't always want to compress I assume that per-column compression would be useful, but I don't have any concrete experience here.
Currently we just pickle a dataframe, do the optional compression thing, and ship it down a wire, so it isn't hard to beat the current state.
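The sample-first heuristic mentioned above looks roughly like the following (an illustrative sketch only; distributed's real maybe_compress differs in its details):

import blosc

def maybe_compress(payload, sample_size=10_000, min_ratio=1.25):
    # Compress a small prefix first; only compress the whole payload
    # if the sample shrinks by at least min_ratio.
    if len(payload) < sample_size:
        return False, payload
    sample = bytes(payload[:sample_size])
    if sample_size / len(blosc.compress(sample, typesize=1)) < min_ratio:
        return False, payload
    return True, blosc.compress(bytes(payload), typesize=1, cname='lz4')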
I think it would make sense to create something like pandas.tools.serial, which can house a DataFrame->bytes conversion and its reverse, so maybe .to_bytes(flavor=, compression=) and .from_bytes(flavor=, compression=), with flavor='pickle' being the default.
we could even have an optional dependency on pyarrow (and handle fallbacks and such).
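A purely illustrative sketch of what such a wrapper could look like (the names to_bytes/from_bytes follow the proposal above and are not an existing pandas API; pa.serialize_pandas appears later in this thread and may be deprecated in current pyarrow in favor of the IPC APIs):

import pickle

def to_bytes(df, flavor='pickle'):
    if flavor == 'pickle':
        return pickle.dumps(df, protocol=pickle.HIGHEST_PROTOCOL)
    elif flavor == 'arrow':
        import pyarrow as pa
        return pa.serialize_pandas(df).to_pybytes()
    raise ValueError("unknown flavor: %s" % flavor)

def from_bytes(buf, flavor='pickle'):
    if flavor == 'pickle':
        return pickle.loads(buf)
    elif flavor == 'arrow':
        import pyarrow as pa
        return pa.deserialize_pandas(buf)
    raise ValueError("unknown flavor: %s" % flavor)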
@jreback is there a Pandas issue for this? This is the sort of thing that I would love to have in Pandas rather than in Dask.
Also relevant links:
import io
import pandas as pd
import pyarrow as pa
df = pd.DataFrame({"a": [1, 2]})
batch = pa.RecordBatch.from_pandas(df)
sink = io.BytesIO()
writer = pa.StreamWriter(sink, batch.schema)
writer.write_batch(batch)
writer.close()
bytebuffer = sink.getvalue()
reader = pa.StreamReader(pa.BufferReader(bytebuffer))
batch_read = reader.get_next_batch()
df_read = batch_read.to_pandas()
On in-memory data you will get slightly better performance (and no GIL issues -- with BytesIO it calls back into Python and must acquire the GIL) if you use Arrow's built-in stream and buffer objects, so pyarrow.InMemoryOutputStream() instead of io.BytesIO().
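Concretely, the stream example above becomes the following with Arrow's own sink (same era-specific pyarrow API as in the earlier snippets; newer pyarrow renames several of these classes):

import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"a": [1, 2]})
batch = pa.RecordBatch.from_pandas(df)

sink = pa.InMemoryOutputStream()       # Arrow-native sink, no BytesIO/GIL round trips
writer = pa.StreamWriter(sink, batch.schema)
writer.write_batch(batch)
writer.close()
buf = sink.get_result()                # an Arrow Buffer (zero-copy)

reader = pa.StreamReader(pa.BufferReader(buf))
df_read = reader.get_next_batch().to_pandas()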
I'm happy to add some convenience functions to Arrow. Opening some JIRAs now. Patches welcome also
To be clear there are three opportunities for GIL issues:
Shuffle-like computations on numpy arrays are currently much faster than they are on Pandas dataframes, so there is definitely a fair bit of performance to squeeze out here.
@mrocklin and of course happy to add a convenience function directly in pandas; now that pyarrow 0.2.0 is out (and in conda) I think that's doable.
@jreback I would advise adding such convenience functions in Arrow so we don't have to set up integration tests inside pandas just yet
happy either way :> though I wouldn't find this a big burden on the current pandas code (nor testing).
Literally this would be a direct wrapper to call a single arrow function.
Yeah, totally. When ARROW-596/597 are done (let's do this month), they'll go out in the 0.3 release (this month). Probably will update conda artifacts before then
I decided to pick this up again yesterday; sorry about the new JIRAs Wes :). My work is at https://github.com/dask/distributed/compare/master...TomAugspurger:arrow-serialization
Here's a notebook with some benchmarks at http://nbviewer.jupyter.org/gist/TomAugspurger/ab01751275b8f5262dabc7fd07a0f19f/serialization.ipynb (scroll down for the plot).
Arrow is pretty similar to pickle, other than object (strings), in which case it's quite a bit faster. Interestingly, Arrow was slower to serialize datetime64[ns]. Perhaps that could be optimized.
An important point is that pickle can use pandas' RangeIndex to avoid serializing the entire index if possible. With a RangeIndex, pickle outperformed Arrow in most cases (other than objects) since it had less work to do. I opened https://issues.apache.org/jira/browse/ARROW-1593 about a simple solution dask could use. We could detect that we have a RangeIndex when serializing, pass the start, stop, and step in our header, and just not use preserve_index (once that's available in pa.serialize_pandas).
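A sketch of that dask-side workaround (illustrative function names only; start/stop/step are public attributes of RangeIndex on recent pandas):

import pandas as pd

def range_index_header(df):
    # Record the RangeIndex parameters instead of serializing the index values.
    if isinstance(df.index, pd.RangeIndex):
        return {'range_index': (df.index.start, df.index.stop, df.index.step)}
    return {}

def restore_range_index(df, header):
    if 'range_index' in header:
        start, stop, step = header['range_index']
        df.index = pd.RangeIndex(start, stop, step)
    return df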
All the tests pass, other than the PeriodIndex test. This type isn't implemented in Arrow (though I think user-defined logical types are on the roadmap). The implementation is pretty clean. Once the metadata issues are fixed on Arrow's side, it'll be even simpler. I'm not concerned about maintenance burden.
I think using Arrow for serialization would be good to add as an option, and probably even the default in a future release.
No need to apologize for new issues, only way to make the software better! We should try to get an ASV setup rolling so that we can have microbenchmarks for all these cases to put some targets on the wall to optimize, then we can make sure that we don't have any perf regressions going forward. Maybe we can add these ASV benchmarks to the nightly runs on the pandas box at some point.
I suspect that the perf difference in float64 data largely has to do with the null bitmap construction -- pickle writes the internal blocks unmodified. This can be partially mitigated through parallel processing when you have many columns (https://issues.apache.org/jira/browse/ARROW-1594)
Not sure if you're using pickle with NumPy arrays anywhere, but if so I recommend you also take a look at using pyarrow.serialize and pyarrow.deserialize as an alternative for moving around ndarrays (or lists or dicts of ndarrays) in dask:
In [16]: import pyarrow as pa
In [17]: import numpy as np
In [18]: import pickle
In [19]: arr = np.random.randn(1000, 1000)
In [20]: %timeit rt = pickle.loads(pickle.dumps(arr))
2.85 ms ± 155 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [21]: %timeit rt = pa.deserialize(pa.serialize(arr).to_buffer())
651 µs ± 30.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Might seem crazy to beat pickle by 4x with an ndarray, but avoiding extra copies will do that
We should try to get an ASV setup rolling so that we can have microbenchmarks for all these cases to put some targets on the wall to optimize, then we can make sure that we don't have any perf regressions going forward. Maybe we can add these ASV benchmarks to the nightly runs on the pandas box at some point.
I'm happy to do that. I'll also update the benchmarks to have versions with and without nulls.
I believe the numpy code is also using pickle. I'll benchmark that as well.
For NumPy we're actually aiming for zero-copy. We serialize the metadata and then just pass down the .data memoryview unmodified (most of the time). People using dask.array tend to be particularly performance conscious; the current system I'm working on has a 24 Gb/s interconnect. Our current bottleneck is actually Tornado, which does a couple of copies internally.
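That NumPy path is roughly the following (a simplified sketch; the real code in distributed's protocol module handles dtypes, strides, and object arrays much more carefully):

import numpy as np

def serialize_ndarray(x):
    x = np.ascontiguousarray(x)                 # copies only if necessary
    header = {'dtype': x.dtype.str, 'shape': x.shape}
    return header, [memoryview(x)]              # the data travels as a zero-copy frame

def deserialize_ndarray(header, frames):
    return np.frombuffer(frames[0], dtype=header['dtype']).reshape(header['shape'])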
You might be interested in looking at our test suite for numpy arrays here: https://github.com/dask/distributed/blob/master/distributed/protocol/tests/test_numpy.py#L30-L65
A few corner cases have bit us in the past that are now recorded there. They might be of general value.
I believe the numpy code is also using pickle. I'll benchmark that as well.
The numpy code is not using pickle.
The numpy code is not using pickle.
Ah, good to know (I just looked at the fallback for object types).
@mrocklin cool, that's helpful, thanks. I opened https://issues.apache.org/jira/browse/ARROW-1596 so that we can expand the test suite based on this
Internally pyarrow.deserialize is using the C API equivalent of np.frombuffer, so I would expect the performance to be the same then.
Adding a couple of thoughts here:
First, if Pandas objects support pickle protocol 5, which it sounds like they may already do ( https://github.com/pandas-dev/pandas/issues/34244 ), and Dask supports pickle protocol 5, which this PR ( https://github.com/dask/distributed/pull/3784 ) implements, then it should be possible to get efficient serialization of Pandas objects just using pickle protocol 5.
Second, if Pandas objects simply consist of other objects that are already serializable with Dask (like NumPy arrays), then it should be possible to traverse objects and serialize everything we can. In fact this is how we added SciPy sparse matrix serialization recently.
cc @TomAugspurger (for awareness)
First if Pandas objects support pickle protocol 5, which it sounds like they may already do ( pandas-dev/pandas#34244 ), and Dask supports pickle protocol 5, which PR ( #3784 ) implements, then it should be possible to get efficient serialization of Pandas objects just using pickle protocol 5.
FYI this is now in Distributed 2.17.0+. Would be interesting to see what sort of mileage people get out of this 🙂
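For anyone wanting to try it directly, a quick illustration of pickle protocol 5 with out-of-band buffers (Python 3.8+, or the pickle5 backport mentioned below; whether pandas actually emits out-of-band buffers depends on the pandas version, per the issue linked above):

import pickle
import numpy as np
import pandas as pd

df = pd.DataFrame({'x': np.arange(1_000_000),
                   'y': np.random.random(1_000_000)})

buffers = []
data = pickle.dumps(df, protocol=5, buffer_callback=buffers.append)
# `buffers` holds PickleBuffer views over the block data; they can be
# compressed or shipped over the wire separately from the small pickle stream.
df2 = pickle.loads(data, buffers=buffers)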
As of Dask + Distributed 2.21.0, we now support pickle protocol 5 on older versions of Python with the pickle5 backport package. This extracts the NumPy arrays used under the hood and preserves their type, which should allow for efficient compression. Here's a quick example:
In [1]: import numpy
...: import pandas
...:
...: from distributed.protocol import serialize, deserialize
In [2]: df = pandas.DataFrame({
...: "a": [1, 2, 3],
...: "b": [0.0, 0.7, 1.3]
...: })
...: df
Out[2]:
a b
0 1 0.0
1 2 0.7
2 3 1.3
In [3]: header, frames = serialize(df)
In [4]: list(map(numpy.asarray, frames[1:]))
Out[4]: [array([[0. , 0.7, 1.3]]), array([[1, 2, 3]])]
In [5]: df2 = deserialize(header, frames); df2
Out[5]:
a b
0 1 0.0
1 2 0.7
2 3 1.3
Closing as this seems to be resolved with pickle protocol 5, but please feel free to reopen if I've missed something.
We need an efficient way to serialize Pandas DataFrames. As of #606 we can now customize this beyond just pickle. There are a number of different data regimes here, including pure numeric data, very compressible numeric data like time series, text data with repeats, categoricals, long text data, etc. We care both about fast encoding and about fast compression for larger results. We want something that has minimal overhead on small dataframes (important for shuffling).
Several options come to mind:
@jreback @wesm @shoyer
It would be good to make this decision with a benchmark in hand. It would be good both to get people's opinions on solutions we should consider as well as some benchmarks that are representative of data that they care about.