Chris-hughes10 / pytorch-accelerated

A lightweight library designed to accelerate the process of training PyTorch models by providing a minimal but extensible training loop that is flexible enough to handle the majority of use cases and capable of utilizing different hardware options with no code changes required. Docs: https://pytorch-accelerated.readthedocs.io/en/latest/
Apache License 2.0
176 stars, 21 forks

Fix remove_padding for higher dimensionality tensors #52

Closed: bepuca closed this 1 year ago

bepuca commented 1 year ago

The current remove_padding functionality does not work for higher-dimensional tensors (ndim > 2). In fact, the existing 3D test was only passing by coincidence, because the first and second dimensions of the tested tensor had the same size.

The padding_mask only needs to operate over the batch dimension, which, by convention, is dim=0. This means it must always be one-dimensional when applied to the tensor. The current code reduces only one dimension, so an ndim-dimensional tensor produces a mask with ndim - 1 dimensions. For tensors with more than two dimensions, the mask therefore has more than one dimension, leading to unexpected and often erroneous results.
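To make the mask-shape issue concrete, here is a minimal sketch in plain PyTorch. It is not the library's code: the function names `remove_padding_single_reduction` and `remove_padding`, and the assumption that a batch item counts as padding only when every one of its elements equals `pad_value`, are hypothetical and chosen for illustration. The first function reduces a single (last) dimension, so its mask keeps ndim - 1 dimensions; the second flattens everything after dim=0 before reducing, so the mask always has shape (batch_size,).

```python
import torch


def remove_padding_single_reduction(tensor: torch.Tensor, pad_value: float) -> torch.Tensor:
    # Hypothetical reproduction of the failure mode, not the library's code:
    # reducing only the last dimension leaves a mask with ndim - 1 dimensions,
    # so for ndim > 2 the boolean index selects along more than the batch dim.
    padding_mask = (tensor != pad_value).any(dim=-1)
    return tensor[padding_mask]


def remove_padding(tensor: torch.Tensor, pad_value: float) -> torch.Tensor:
    # Hypothetical fixed version. Assumption: a batch item is padding only if
    # all of its elements equal pad_value, so we keep items with any other value.
    if tensor.ndim == 1:
        # Each batch item is a single scalar, so compare element-wise.
        keep = tensor != pad_value
    else:
        # Flatten every non-batch dimension, then reduce over it, giving a
        # 1D mask of shape (batch_size,) regardless of tensor.ndim.
        keep = (tensor != pad_value).flatten(start_dim=1).any(dim=1)
    return tensor[keep]


# Distinct sizes per dimension: (batch=4, seq=3, features=2).
batch = torch.zeros(4, 3, 2)
batch[0] = 1.0
batch[2] = 2.0  # items 1 and 3 are all padding (pad_value = 0)

print(remove_padding(batch, pad_value=0).shape)                   # torch.Size([2, 3, 2])
print(remove_padding_single_reduction(batch, pad_value=0).shape)  # torch.Size([6, 2])
```

Giving every dimension a different size (4, 3, 2 here) is also what stops a test like this from passing by accident: with a multi-dimensional mask, the wrong result has a visibly different shape instead of happening to line up with the expected one.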

This PR fixes these issues and adds tests to verify that the code now works. The previous tests are kept to show that there is no regression.