uber / petastorm

Petastorm library enables single machine or distributed training and evaluation of deep learning models from datasets in Apache Parquet format. It supports ML frameworks such as Tensorflow, Pytorch, and PySpark and can be used from pure Python code.
Apache License 2.0

Using threads taking longer than dummy #650

Closed Jomonsugi closed 3 years ago

Jomonsugi commented 3 years ago

This issue looked like it might shed light on my problem, but it seems I'm missing something more fundamental. When I run the script below to simply iterate through the BatchedDataLoader, using the thread pool actually takes longer than using the main worker via dummy, and no matter what, more workers results in a longer processing time. I'm running this on a machine with 16 vCPUs and 64 GiB of memory. Memory stays stable through the process, but I notice CPU usage spikes throughout these iterations. No matter how long a single iteration takes, though, using multiple workers should take less time overall, right? The dataset I'm testing on is small: around ~1.4 GiB saved in parquet format in s3, ~1.5M rows, and I'm only running 100 iterations with batch size 100. Maybe the script is not working as intended?

Script:

from itertools import product
import sys
import time

from petastorm import make_batch_reader
from petastorm.pytorch import BatchedDataLoader

# s3_path, cols, and trsfm_spec are defined elsewhere in the full script
pools = ['dummy', 'thread']
workers = [1, 4, 8]
for i in product(pools, workers):
    print(i)

# the trsfm_spec is doing a few different transformations on various columns given the pandas dataframe
for pool_type, workers_count in product(pools, workers):
    with make_batch_reader(s3_path,
                           workers_count=workers_count,
                           transform_spec=trsfm_spec,
                           schema_fields=cols,
                           num_epochs=10,
                           reader_pool_type=pool_type) as reader:

        loader = BatchedDataLoader(reader, batch_size=100)
        loader_iter = iter(loader)

        time_sum = 0
        batch_size_sum = 0
        batches = 100
        print('--')
        print(f'workers_count: {workers_count} and pool_type: {pool_type}')
        loop_start = time.time()
        for batch_idx in range(batches):
            start = time.time()
            batch = next(loader_iter)
            end = time.time()

            t = end - start
            time_sum += t
            batch_size = sys.getsizeof(batch)
            # exclude the first two batches from the size average
            if batch_idx not in (0, 1):
                batch_size_sum += batch_size
        loop_end = time.time()
        print(f'time sum: {time_sum}')
        print(f'average time to process batch: {time_sum / (batches - 2)}')
        print(f'loop time: {loop_end - loop_start}')
        print(f'average batch size: {batch_size_sum / batch_idx}')

Output:

('dummy', 1)
('dummy', 4)
('dummy', 8)
('thread', 1)
('thread', 4)
('thread', 8)
--
workers_count: 1 and pool_type: dummy
time sum: 2.927597761154175
average time to process batch: 0.029873446542389537
loop time: 2.927725315093994
average batch size: 1172.040404040404
--
workers_count: 4 and pool_type: dummy
time sum: 4.17024040222168
average time to process batch: 0.04255347349205796
loop time: 4.1703784465789795
average batch size: 1172.040404040404
--
workers_count: 8 and pool_type: dummy
time sum: 4.148790121078491
average time to process batch: 0.042334593072229504
loop time: 4.1489198207855225
average batch size: 1172.040404040404
--
workers_count: 1 and pool_type: thread
time sum: 4.524890422821045
average time to process batch: 0.04617235125327597
loop time: 4.525023937225342
average batch size: 1172.040404040404
--
workers_count: 4 and pool_type: thread
time sum: 8.192336320877075
average time to process batch: 0.08359526858037832
loop time: 8.192479372024536
average batch size: 1172.040404040404
--
workers_count: 8 and pool_type: thread
time sum: 11.62452483177185
average time to process batch: 0.11861760032420256
loop time: 11.624683380126953
average batch size: 1172.040404040404

selitvin commented 3 years ago

Hi. Thank you for the detailed reproduction snippet. It helped me (I hope) understand what is happening in your case.

Here are a couple of ideas:

  1. When the 'dummy' pool type is used, the workers_count argument has no effect and all loading happens on the main thread.
  2. You set batch_size to 100; however, the minimal atomic read unit is a row-group, which would typically default to more than a hundred megabytes. This means that for 100 batches of 100 rows you need only a single row-group to finish your measurement loop, so the parallelism (i.e. workers_count>1) is not expected to benefit your code. Actually, the opposite is more likely: getting the first sample is going to be slower with workers_count>1, because multiple threads are competing for your network, so time-to-first-sample is likely to be higher.
  3. The implementation of BatchedDataLoader appears to be very inefficient when batch_size<<rows_in_a_rowgroup. However, if you introduce shuffling, a different, more efficient implementation is used. Try BatchedDataLoader(..., shuffling_queue_capacity=100000) (or whatever size of shuffling queue you like), and you'll see much faster performance. I'll take a look at fixing the performance of BatchedDataLoader for the non-shuffling case; it can probably be fixed reasonably easily.
  4. Also, you can try next(reader) instead of next(loader_iter): you'll be getting entire row-groups as the result, so it is better to reduce the value of batches in your code. That way you will be able to observe the expected effect of workers_count (both variations are sketched right after this list).
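
A minimal sketch of ideas 3 and 4, assuming the same s3_path, cols, and trsfm_spec as in the script above (the reader arguments here are illustrative only):

from petastorm import make_batch_reader
from petastorm.pytorch import BatchedDataLoader

# Idea 3: enabling the shuffling queue switches BatchedDataLoader to the
# more efficient code path when batch_size is much smaller than a row-group.
with make_batch_reader(s3_path,
                       workers_count=4,
                       transform_spec=trsfm_spec,
                       schema_fields=cols,
                       num_epochs=10,
                       reader_pool_type='thread') as reader:
    loader = BatchedDataLoader(reader, batch_size=100,
                               shuffling_queue_capacity=100000)
    for batch in loader:
        pass  # consume batches here

# Idea 4: iterating the reader directly returns one whole row-group per step,
# which makes the effect of workers_count on read throughput visible.
with make_batch_reader(s3_path,
                       workers_count=4,
                       schema_fields=cols,
                       num_epochs=1,
                       reader_pool_type='thread') as reader:
    for row_group in reader:
        pass  # row_group is a namedtuple of column value arrays
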
Jomonsugi commented 3 years ago

Thank you for your response.

In response to your ideas above:

  1. I did know this, but wanted to confirm the expected behavior. I appreciate the clarification, though.

  2. This is the fundamental concept I was missing. I did not understand that each worker processes a row group, and that the benefit only comes if more than one row group is being accessed given the number of batches iterated through.

  3. I did play around with the shuffling_queue_capacity argument. To be honest, I found my results to be "fiddly": depending on the size of the data, the partition sizes, the row group sizes, and even the transform_spec, adding this argument sometimes increased and sometimes decreased latency. I humbly suggest a more detailed overview in your documentation of how to increase throughput based on all these variables in conjunction with shuffling_queue_capacity, or at least some pointers to help guide user experimentation.

  4. As noted above, I missed this concept. After iterating over row groups instead of batches, I started to see results that made more sense: one worker was always less efficient than more than one worker, although I did find that using all of the workers available on my instance was consistently less efficient than using fewer than all of them. Not a big deal, as I can access instances with many more workers if need be.

('thread', 4)
('thread', 3)
('thread', 2)
('thread', 1)
--
workers_count: 4 and pool_type: thread
batches: 10
loop time: 9.915568590164185
average time to process batch: 1.1017176575130887
average batch size: 64.0
--
workers_count: 3 and pool_type: thread
batches: 10
loop time: 7.1100451946258545
average time to process batch: 0.7899945047166612
average batch size: 64.0
--
workers_count: 2 and pool_type: thread
batches: 10
loop time: 7.345884323120117
average time to process batch: 0.8162005477481418
average batch size: 64.0
--
workers_count: 1 and pool_type: thread
batches: 10
loop time: 11.899104833602905
average time to process batch: 1.322114732530382
average batch size: 64.0

Note that I vastly slimmed down what was happening in the TransformSpec in this test, which accounts for the lower latency compared to my original example, where even one batch would take nearly half a second or more. The latency of loading the data ended up being a red herring. The amount of transformation needed to make petastorm's BatchedDataLoader (pandas input/output, no access to pytorch's collate_fn for a chance at decreasing latency) work as a solution for input to our particular pytorch model is unwieldy. We simply have too many columns where vectors of unequal size need to be transformed after being loaded from parquet, which prevents a vectorized transformation and thus, it seems, an efficient use of the BatchedDataLoader process.
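
As a rough illustration of the kind of pandas-side work described above (the field name, vocabulary, and padded length are hypothetical, not the actual trsfm_spec):

import numpy as np
from petastorm import TransformSpec

MAX_LEN = 32   # hypothetical padded length
vocab = {}     # hypothetical value-to-index mapping built elsewhere

def _pad_row(values):
    # per-row Python work: dictionary lookup to map values to indexes,
    # then zero-padding to a fixed length
    idx = [vocab.get(v, 0) for v in values][:MAX_LEN]
    return np.pad(idx, (0, MAX_LEN - len(idx))).astype(np.int64)

def _transform(df):
    # TransformSpec hands in a pandas DataFrame and expects one back,
    # so the variable-length column goes through a row-wise .apply
    df['feature_ids'] = df['feature_ids'].apply(_pad_row)
    return df

trsfm_spec = TransformSpec(
    _transform,
    edit_fields=[('feature_ids', np.int64, (MAX_LEN,), False)])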

selitvin commented 3 years ago

Do you think having an option for the user to supply their own collate function would be helpful in your case? This is something that was brought up in #647, and I will try to address it in the following weeks.

Jomonsugi commented 3 years ago

Maybe. In my case I have arrays of unequal size that need to be 1) mapped to indexes and 2) padded with 0's. For padding, pytorch has a way to deal with tensors of uneven size via collate_fn at the batch level through its DataLoader object, which would definitely be faster than what I'm doing right now with an .apply through pandas. For indexing, if I can find a solution through pytorch that is quicker than the python dictionary lookup I am using to map values to indexes, then the answer again would be yes.
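
A minimal sketch of the kind of collate function meant here, assuming each sample arrives as a variable-length 1-D LongTensor of already-indexed values; the padding uses torch.nn.utils.rnn.pad_sequence:

import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    # batch: list of 1-D LongTensors of unequal length.
    # pad_sequence stacks them into a (batch, max_len) tensor,
    # filling the shorter sequences with zeros.
    return pad_sequence(batch, batch_first=True, padding_value=0)

# usage with a plain pytorch DataLoader (dataset is hypothetical):
# loader = torch.utils.data.DataLoader(dataset, batch_size=100, collate_fn=pad_collate)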

In my opinion, the more that can be done with pytorch, the better. Full control over the batch of pytorch tensors produced by BatchedDataLoader before it is fed into the model could be helpful for many other use cases as well.

selitvin commented 3 years ago

Agreed. We'll try moving in this direction. Thank you for your input!