pytoolz / toolz

A functional standard library for Python.
http://toolz.readthedocs.org/

Dictionary chunking / batching (inverting merge_with) #580

Open BrandonSmithJ opened 4 weeks ago

BrandonSmithJ commented 4 weeks ago

The following seems like it would be a pretty common use case:

import tlz
c = [{'a':1, 'b':2}, {'a':3, 'b':4}, {'a':5, 'b':6}]
d = tlz.merge_with(list, c)
print(d) # {'a': [1, 3, 5], 'b': [2, 4, 6]}

# Now chunk to a certain size
e = chunk_dict(2, d)
print(e) # [{'a':[1,3], 'b':[2,4]}, {'a':[5], 'b':[6]}]

In other words: I first combine a bunch of dictionaries into a single dict whose values are concatenated lists, and then I want to split that merged dict back into a list of dicts whose values have at most some specified length.
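
For a single key the building block already exists (partition_all), though note it yields tuples, which is why the numpy case below needs special handling:

from tlz import partition_all
list(partition_all(2, [1, 3, 5]))  # [(1, 3), (5,)]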

Perhaps even more useful is what I actually want it for: inverting merge_with. That is, I have a list of batched data to put through a model, and I want to change the batch size:

import numpy as np

# Merge batched data together
batch_size3 = [{'a': np.array([1,2,3]), 'b':np.array([4,5,6])}, {'a':np.array([7]), 'b':np.array([8])}]
all_batches = tlz.merge_with(np.hstack, batch_size3)
print(all_batches) # {'a': array([1, 2, 3, 7]), 'b': array([4, 5, 6, 8])}

# Invert merge_with using different batch size
batch_size2 = chunk_dict(2, all_batches)
print(batch_size2) # [{'a': array([1, 2]), 'b': array([4, 5])}, {'a': array([3, 7]), 'b': array([6, 8])}]
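
To make "inverting merge_with" concrete, here is the property I expect to hold (a sanity check using the chunk_dict defined below and the arrays from the snippet above; it assumes all values have equal length, as they do here):

# Re-merging the re-batched chunks should recover the original arrays
rebatched = tlz.merge_with(np.hstack, chunk_dict(2, all_batches))
assert all(np.array_equal(rebatched[k], all_batches[k]) for k in all_batches)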

Unless I'm missing something, there doesn't seem to be a straightforward way to do this with toolz (though it seems like functionality toolz would have). Here's the best solution I've come up with so far:

from itertools import starmap, zip_longest
from typing import Iterator
from math import ceil
from tlz import merge, partition_all

def chunk_dict(n: int, d: dict) -> Iterator[dict]:
    """Chunk a dict of lists into separate dicts with lists of max length `n`.

    Parameters
    ----------
    n : int
        Chunk size, i.e. max length of the new dictionary values.
    d : dict
        Dictionary of iterables to chunk.

    Returns
    -------
    Iterator[dict]
        Dictionaries whose values are now of length at most `n`.

    """
    def chunk(k, v):
        """Slice sequence-like values directly so that, e.g., numpy
        arrays remain array objects instead of becoming tuples."""
        if hasattr(v, '__len__') and hasattr(v, '__getitem__'):
            try:
                v[0:0]  # probe eagerly; a lazy genexp would hide slicing errors
            except Exception:
                pass  # object may still not support slicing
            else:
                return ({k: v[i*n:(i+1)*n]} for i in range(ceil(len(v) / n)))
        return ({k: part} for part in partition_all(n, v))
    # Zip the per-key generators position-wise (padding ragged lengths
    # with {}) and merge each position into one dict.
    return map(merge, zip_longest(*starmap(chunk, d.items()), fillvalue={}))
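
For reference, a quick sanity check that this reproduces the expected output from the first example:

print(list(chunk_dict(2, {'a': [1, 3, 5], 'b': [2, 4, 6]})))
# [{'a': [1, 3], 'b': [2, 4]}, {'a': [5], 'b': [6]}]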

Is there a better way of doing any of this, and would it be useful to add this type of function to toolz?