Chris-hughes10 / pytorch-accelerated

A lightweight library designed to accelerate the process of training PyTorch models by providing a minimal, but extensible training loop which is flexible enough to handle the majority of use cases, and capable of utilizing different hardware options with no code changes required. Docs: https://pytorch-accelerated.readthedocs.io/en/latest/
Apache License 2.0

Improve reporting of issues with batch size on distributed evaluate #53

Closed bepuca closed 1 year ago

bepuca commented 1 year ago

When running a Trainer.evaluate on a multi-GPU configuration, the interaction between the batch size and the size of the eval dataset may cause either of the following (a rough arithmetic sketch follows the list):

  1. One process having a different number of samples than the rest on the last batch.
  2. One or more processes having 0 samples on the last batch.
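
To make the two failure modes concrete, here is a minimal arithmetic sketch. It is not library code and it assumes a specific sharding scheme (each global batch of `batch_size * num_processes` samples is split contiguously across processes, with no padding or sample duplication); the real behaviour depends on the sampler and the accelerate configuration in use.

```python
def last_batch_sizes(dataset_size: int, batch_size: int, num_processes: int) -> list[int]:
    """Illustrative only: per-process sample counts in the final batch,
    assuming contiguous sharding of the last global batch, no padding."""
    remainder = dataset_size % (batch_size * num_processes)
    if remainder == 0:
        return [batch_size] * num_processes  # last global batch is full everywhere
    return [
        min(batch_size, max(0, remainder - rank * batch_size))
        for rank in range(num_processes)
    ]

# Case 1: the last batch is smaller on one process than on the others.
print(last_batch_sizes(dataset_size=30, batch_size=4, num_processes=2))  # [4, 2]

# Case 2: one or more processes receive no samples at all in the last batch.
print(last_batch_sizes(dataset_size=18, batch_size=4, num_processes=4))  # [2, 0, 0, 0]
```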

In most (if not all) cases, at some point during or after an evaluate run, there will be a call to Trainer.gather to synchronize the results across processes. If either of the two conditions above occurs, that call will freeze the run without any kind of reporting, which can be extremely frustrating and hard to debug if the user has not encountered it before.

Number 1 can be solved by passing a padding_value to Trainer.gather. Number 2 can only be solved by setting a batch size that avoids the issue.
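
As a rough illustration of why a padding value helps with case 1, here is a sketch of the general idea (this is not the Trainer.gather implementation; the tensors and the sentinel value are made up): each rank pads its local result up to the largest length across ranks, so every process contributes an identically shaped tensor to the collective gather.

```python
import torch

# Pretend these are the per-process prediction tensors for the last batch
# (case 1: one rank ended up with 4 samples, the other with only 2).
local_results = [torch.arange(4), torch.arange(2)]
padding_value = -1  # sentinel, assumed to be stripped out after gathering

max_len = max(t.shape[0] for t in local_results)
padded = [
    torch.nn.functional.pad(t, (0, max_len - t.shape[0]), value=padding_value)
    for t in local_results
]

# In a real run each rank would pass its padded tensor to the collective gather;
# here we simply concatenate to mimic the gathered result on one process.
gathered = torch.cat(padded)
print(gathered)                             # tensor([ 0,  1,  2,  3,  0,  1, -1, -1])
print(gathered[gathered != padding_value])  # padding stripped back out
```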

This PR adds a warning for number 1, so that users are prompted to make the gather call correctly, and raises an error for number 2, to force the user to pick an adequate batch size.
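
For reference, a check along these lines could look roughly like the sketch below. This is hypothetical, not the code added in this PR: the function name and messages are made up, and it reuses the same contiguous-sharding assumption as the arithmetic sketch above. Whether case 2 should raise or merely warn is exactly the open question discussed next.

```python
import warnings


def check_eval_batch_distribution(dataset_size: int, batch_size: int, num_processes: int) -> None:
    """Hypothetical sketch of the kind of check discussed in this PR."""
    remainder = dataset_size % (batch_size * num_processes)
    if remainder == 0:
        return  # every process sees full batches; nothing to report
    if remainder <= batch_size * (num_processes - 1):
        # Case 2: at least one process gets no samples in the last batch,
        # which would hang a later gather across processes.
        raise ValueError(
            "The eval dataset size and batch size leave at least one process with "
            "no samples in the last batch; please choose a different batch size."
        )
    # Case 1: every process gets samples, but the last batch is uneven, so a
    # gather will only work if a padding value is provided.
    warnings.warn(
        "The last eval batch is uneven across processes; provide a padding value "
        "when gathering results to avoid a hang."
    )
```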

There is one point open to discussion though:

Users that do not need to gather (because they save the outputs in a distributed fashion, or because they do not need to persist the outputs at all, for instance) will never be hit by these cases. This is fine for number 1, because only a warning is raised. The issue is that, as currently defined, number 2 would force those users to switch to a different batch size (potentially a suboptimal one) for something they do not care about.

Then the question becomes: do we just warn, or do we raise an error, for case 2?

Chris-hughes10 commented 1 year ago

I think a warning for point 2 should be fine