ml-explore / mlx-examples

Examples in the MLX framework

iterate_batches in mlx_lm's Lora trainer is discarding the remainder dataset items (modulo batch size) #843

Open chimezie opened 1 week ago

chimezie commented 1 week ago

The current implementation of iterate_batches produces batches for all but the remaining N items in the dataset (where N is less than the batch size). So, if your dataset size is a multiple of the batch size, you will eventually train on every item in the dataset. However, if it is not, the remainder will never be included in what is trained, regardless of how many iterations you set.

Below is a very minimal test case that replicates this (it iterates over batches totaling at most 46 items, twice the dataset size, without ever encountering the remaining items before stopping):

import unittest
from unittest.mock import MagicMock

import mlx.core as mx

from mlx_lm.tuner.trainer import iterate_batches


class TestBatching(unittest.TestCase):
    def test_batch_remainder(self):
        # Mock tokenizer: "foo" encodes to token 1, anything else ("bar") to token 7.
        tokenizer = MagicMock()
        tokenizer.eos_token_id = 9
        tokenizer.encode = lambda x: [1,] if x == "foo" else [7,]

        # The dataset is 20 "foo"s and 3 "bar"s (the remainder).
        dataset = ["foo"] * 20 + ["bar", "bar", "bar"]
        batch_size = 10
        found_remainders = False
        for idx, (batch_in, _, _) in enumerate(iterate_batches(dataset, tokenizer, batch_size,
                                                               max_seq_length=2048, train=True)):
            # Check whether any "bar" token (7) appears in the batch.
            found_remainders = mx.any(batch_in == mx.full(batch_in.shape, 7))
            # Stop once the remainder is found, or after iterating over twice the dataset size.
            if found_remainders or ((idx + 1) * batch_size) > len(dataset) * 2:
                break
        self.assertTrue(found_remainders.item())
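
For intuition, the behavior comes down to the batching arithmetic: only full batches are formed, so the trailing dataset_size % batch_size items are never yielded. A simplified sketch of that arithmetic (not the actual iterate_batches code):

# Simplified sketch, not the actual mlx_lm implementation: only full batches
# are formed, so the trailing (dataset_size % batch_size) items are skipped.
dataset_size = 23   # 20 "foo"s + 3 "bar"s
batch_size = 10

full_batch_starts = range(0, dataset_size - batch_size + 1, batch_size)
batches = [list(range(start, start + batch_size)) for start in full_batch_starts]

print(len(batches))               # 2 full batches per pass -> 20 items covered
print(dataset_size % batch_size)  # 3 items (the "bar"s) never appear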
awni commented 1 week ago

> The current implementation of iterate_batches produces batches for all but the remaining N items in the dataset (where N is less than the batch size). So, if your dataset size is a multiple of the batch size, you will eventually train on every item in the dataset. However, if it is not, the remainder will never be included in what is trained, regardless of how many iterations you set.

Yes, that's how it works now. We should probably just change it to not drop the last batch even when it is not the same size as the others. I will mark this as an enhancement.
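
As a rough illustration of that change, assuming the trainer builds batches from a list of shuffled indices (the function and parameter names below are hypothetical, not the actual trainer code), keeping the final partial batch could look like:

# Hypothetical sketch of the proposed behavior, not the actual trainer code.
def make_batches(indices, batch_size, drop_last=False):
    # Group indices into consecutive batches, including a trailing partial one.
    batches = [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]
    # Optionally drop a trailing partial batch; keeping it means the remainder
    # items are also trained on.
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches = batches[:-1]
    return batches

# With 23 items and batch_size=10 this yields batches of sizes 10, 10, and 3.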