UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

IterableDataset with v3.0.0 #2698

Closed david-waterworth closed 1 month ago

david-waterworth commented 5 months ago

I'm in the process of migrating my train.py to 3.0.0 (specifically for restart from checkpoint for AWS spot instance training, and once that's working I'm keen to experiment with a 4 GPU instance).

My code uses a custom dataset which was adapted from (https://github.com/UKPLab/sentence-transformers/blob/4f94c165cfcb2e619935261cb7366b0de2412365/sentence_transformers/datasets/SentenceLabelDataset.py#L12-L102) and is a torch.utils.data.IterableDataset. Am I correct that this is no longer supported? I.e. I get object has no attribute 'cache_files', and the documentation only mentions torch.utils.data.Dataset and datasets.DatasetDict.

Edit: I misread the documentation - it's datasets.Dataset, not torch.utils.data.Dataset.

I found a reference to BatchSamplers - is this what I should use, i.e. create a Dataset and a custom BatchSampler that implements the sampling logic? For now I'm just using the original code to generate a new dataset that contains "anchor", "positive" pairs using round-robin sampling - but that relies on the trainer not shuffling the data, so I'm not 100% sure it's correct at the moment.
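
For context, a rough sketch of what that generated dataset looks like (the records here are just illustrative; only the "anchor"/"positive" column layout matters):

from datasets import Dataset

# Illustrative (text, label) records standing in for the real data.
records = [
    {"text": "zone temp reading for level 3 ahu", "label": "zone_temperature_sensor"},
    {"text": "level 3 ahu zone temp setpoint", "label": "zone_temperature_setpoint"},
]

# One row per pair, with the text as the anchor and its label as the positive.
train_dataset = Dataset.from_dict({
    "anchor": [r["text"] for r in records],
    "positive": [r["label"] for r in records],
})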

david-waterworth commented 5 months ago

So it looks like I will have to implement my own BatchSampler. The original IterableDataset effectively samples one example of each label in a round-robin fashion. Unless the seed was reset at the start of each epoch, the examples and/or their order could change on each epoch. Also, even with sample-with-replacement set to false, the original code resets already_seen={} on each pass through label_range, so it can return the same example twice in an epoch and may not return every example.

So the easiest way of directly replicating this seems to be to create a custom BatchSampler, but this doesn't seem extensible: the trainer args accept an enum, not an instance, so in order to extend it, it looks like I have to subclass the trainer and override https://github.com/UKPLab/sentence-transformers/blob/4f94c165cfcb2e619935261cb7366b0de2412365/sentence_transformers/trainer.py#L446 ?

It seems more flexible to allow batch samplers to be passed via config. I think it would work if you just passed a class definition rather than an instance (i.e. samplers.NoDuplicatesBatchSampler so the trainer can instantiate it lazily later)

Also, it looks like the default sampler is https://github.com/UKPLab/sentence-transformers/blob/4f94c165cfcb2e619935261cb7366b0de2412365/sentence_transformers/trainer.py#L471 which, as far as I know, performs shuffling? Won't this affect CachedMultipleNegativesRankingLoss, where each batch is supposed to consist only of (anchor, positive) samples with no duplicates over the positives? Is the intent here to use NoDuplicatesBatchSampler?
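
For reference, the kind of override I mean would look roughly like this (MyRoundRobinBatchSampler is a placeholder for the custom sampler, and the exact signature of the linked method should be checked against trainer.py):

from sentence_transformers import SentenceTransformerTrainer


class MyTrainer(SentenceTransformerTrainer):
    # Return a custom batch sampler instead of one selected via the BatchSamplers enum.
    # MyRoundRobinBatchSampler is a placeholder; the argument list here mirrors what the
    # linked trainer.py method appears to take, so double-check it against your version.
    def get_batch_sampler(self, dataset, batch_size, drop_last, valid_label_columns=None, generator=None):
        return MyRoundRobinBatchSampler(dataset, batch_size=batch_size, drop_last=drop_last)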

david-waterworth commented 5 months ago

Another update, with a bit more context on what I'm trying to do:

I have a dataset consisting of (text, label) pairs where label is a short structured text (they're sensor types from an IoT network, e.g. zone_temperature_sensor or zone_temperature_setpoint). There are ~4k labels, and I've found MultipleNegativesRankingLoss to work really well in the past with (a_i, p_i) == (text, label).

MultipleNegativesRankingLoss requires pairs (a_1, p_1), (a_2, p_2), ..., (a_n, p_n) where (a_i, p_i) is a positive pair and (a_i, p_j) for i != j is a negative pair.

That means I need to construct mini-batches where the p_i's are unique within the batch. Previously I used SentenceLabelDataset with samples_per_label=1, which essentially samples batch_size labels and then samples one example of each label. It doesn't return every example per epoch - once it has cycled through all labels it resets already_seen.

I cannot see how to do this "out-of-the-box" in v3.0.0, so I monkey patched the trainer so that I can use a custom sampler. I couldn't use NO_DUPLICATES because it will include (a_1, p_1), (a_2, p_2) if a_1 != a_2 even if p_1 == p_2. I hacked it to only consider duplicates over the p_i's, but the runtime was 10x slower. So I implemented SentenceLabelDataset as a BatchSampler with monkey patching, and it's working as before.

Dual GPU with DDP is working very nicely though, and resume from checkpoint! So I'm close to being able to train on a 4GPU AWS spot instance.

tomaarsen commented 5 months ago

Hello!

I appreciate all the effort that you're putting in here! Indeed, unlike pre 3.0, we now primarily rely on datasets.Dataset. Beyond that, the BatchSampler code is indeed not very extensible, primarily because I didn't intend for people to have to write their own. For reference, my intention was to replace https://github.com/UKPLab/sentence-transformers/blob/4f94c165cfcb2e619935261cb7366b0de2412365/sentence_transformers/datasets/SentenceLabelDataset.py#L12-L102 with the GroupByLabel batch sampler: https://github.com/UKPLab/sentence-transformers/blob/cb35f0dc353805af68175e896adfce2b7fc3cb78/sentence_transformers/trainer.py#L463-L469

It seems more flexible to allow batch samplers to be passed via config. I think it would work if you just passed a class definition rather than an instance (i.e. samplers.NoDuplicatesBatchSampler so the trainer can instantiate it lazily later)

This was my original design, but it proved problematic once I realised that different batch samplers required different arguments: https://github.com/UKPLab/sentence-transformers/blob/cb35f0dc353805af68175e896adfce2b7fc3cb78/sentence_transformers/trainer.py#L454-L476

Also, it looks like the default sampler is BatchSamplers.BATCH_SAMPLER which, as far as I know, performs shuffling? Won't this affect CachedMultipleNegativesRankingLoss, where each batch is supposed to consist only of (anchor, positive) samples with no duplicates over the positives? Is the intent here to use NoDuplicatesBatchSampler?

The recommendation is to use BatchSamplers.NO_DUPLICATES when using loss functions with in-batch negatives. This is equivalent to the old NoDuplicatesBatchSampler. I.e., NoDuplicatesBatchSampler inspired BatchSamplers.NO_DUPLICATES and SentenceLabelDataset inspired BatchSamplers.GROUP_BY_LABEL.
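
For example, configuring it via the training arguments looks roughly like this (the other argument values are placeholders):

from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

args = SentenceTransformerTrainingArguments(
    output_dir="output",
    per_device_train_batch_size=64,
    # For in-batch-negatives losses such as (Cached)MultipleNegativesRankingLoss:
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    # Or, as the replacement for the old SentenceLabelDataset behaviour (requires a label column):
    # batch_sampler=BatchSamplers.GROUP_BY_LABEL,
)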

I cannot see how to do this "out-of-the-box" in v3.0.0, so I monkey patched the trainer so that I can use a custom sampler. I couldn't use NO_DUPLICATES because it will include (a_1, p_1), (a_2, p_2) if a_1 != a_2 even if p_1 == p_2. I hacked it to only consider duplicates over the p_i's, but the runtime was 10x slower. So I implemented SentenceLabelDataset as a BatchSampler with monkey patching, and it's working as before.

You're very right, that is kind of an oversight on my part. Perhaps I should update the NoDuplicatesBatchSampler to ignore the first column, as that column will never be used as a negative in any of the current loss functions? I do agree that it might be a lot slower, as it'll just iterate over all samples until it finds "labels" that aren't already in the batch, rather than first finding enough "labels" and then sampling 1 training sample for each "label". I can try and reimplement it to be faster, but it does seem like perhaps people would expect the NO_DUPLICATES batch sampler to ignore the first column.
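
To make that concrete, a standalone sketch of that "group first, then sample one per group" idea applied to the positive column (this is not the library implementation; the class name and the "positive" column name are just illustrative, and it yields lists of dataset indices):

import random
from collections import defaultdict


class UniquePositiveBatchSampler:
    """Sketch: group indices by the "positive" column, then build each batch from
    distinct positives so no positive appears twice within a batch."""

    def __init__(self, dataset, batch_size: int, drop_last: bool = True, seed: int = 42):
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.seed = seed
        # Map each distinct positive text to the indices of its examples.
        self.groups = defaultdict(list)
        for idx, positive in enumerate(dataset["positive"]):
            self.groups[positive].append(idx)

    def __iter__(self):
        rng = random.Random(self.seed)
        positives = list(self.groups)
        rng.shuffle(positives)
        batch = []
        for positive in positives:
            # One randomly chosen example per distinct positive.
            batch.append(rng.choice(self.groups[positive]))
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        n_batches, remainder = divmod(len(self.groups), self.batch_size)
        return n_batches if self.drop_last or remainder == 0 else n_batches + 1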

olivierr42 commented 5 months ago

@david-waterworth thank you for bringing this up!!

@tomaarsen Regarding streaming datasets, is there any intention to eventually support datasets.IterableDataset in SentenceTransformers in the near future?

I am dealing with fairly large pairs datasets that do not fit in memory, so using datasets.Dataset is somewhat limiting in that regard. Are there easy workarounds I could implement to achieve this while still using SentenceTransformerTrainer?

Thank you!

david-waterworth commented 4 months ago

Thanks @tomaarsen - yeah, NO_DUPLICATES would definitely have to be made faster; it started very slow (10x slower, as I mentioned) but got slower and slower as it progressed through the epoch. Also, one of the big differences between BatchSamplers.NO_DUPLICATES and the original SentenceLabelDataset approach is that you can reach a point where it's no longer possible to return a batch containing no duplicates before the end of the epoch (i.e. the number of batches actually returned and the value returned by __len__ are not consistent). I have no idea whether this is important or not?

I was wondering if BatchSamplers.GROUP_BY_LABEL is what I'm looking for; I'll take a closer look today. Plus I probably need to rebalance my dataset so that I have num_batches examples for each class (assuming I set the batch size to the number of classes).

Also, BTW, I got the transformers early stopping callback working yesterday, which was nice! Thanks for all the work on this major release!

david-waterworth commented 4 months ago

@tomaarsen so as an FYI/aside, I've got DDP training working on a single AWS ml.g5.12xlarge spot instance (4 GPUs) with checkpointing. It took a small amount of fiddling. I couldn't get the official AWS PyTorch estimator and the 763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training:2.3.0-cpu-py311-ubuntu20.04-sagemaker image to work; once it got to actual training it hung. But I was originally using my own (BYOC) image, and with a bit of tweaking I got that to work. I had to study their code for a bit to find out how to start DDP (torchrun) with BYOC, as that wasn't obvious, but it's working now. Happy to create a new thread with some notes if you want?

tomaarsen commented 4 months ago

Regarding streaming datasets, is there any intention to eventually support datasets.IterableDataset in SentenceTransformers in the near future?

I've got some changes locally that allow this to work indeed! Perhaps I can get a PR ready tomorrow, but it may only be published from the 17th onwards. There are still some changes necessary, e.g. better errors (as IterableDatasetDict will not be supported, and multiple training datasets will also not be supported, as IterableDataset can't be used with BatchSampler instances), but it does work for me locally now.
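
In the meantime, one possible workaround for the memory pressure is to build the datasets.Dataset with Dataset.from_generator, which writes an Arrow cache to disk and memory-maps it rather than keeping all rows in RAM. A rough sketch (the file path and column names are placeholders):

from datasets import Dataset


def pair_generator():
    # Stream (anchor, positive) pairs from disk rather than loading them all into memory.
    with open("pairs.tsv", encoding="utf-8") as f:
        for line in f:
            anchor, positive = line.rstrip("\n").split("\t")
            yield {"anchor": anchor, "positive": positive}


# The resulting Dataset is backed by an on-disk Arrow cache and memory-mapped.
train_dataset = Dataset.from_generator(pair_generator)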

@david-waterworth Glad to hear that you're making progress and getting things working. Feel free to share those notes, I'm open to all feedback at this time.

ganeshkrishnan1 commented 4 months ago

Are there any workarounds to support StreamingDataset now? Our data is around 50GB and the dataset memory usage is approaching the limits of our machine at 250GB. I get this error when using StreamingDataset:

AttributeError: 'IterableDataset' object has no attribute 'cache_files'

waileong-leong commented 4 months ago

@david-waterworth are you open to sharing the sampler that you developed to replace SentenceLabelDataset?

david-waterworth commented 3 months ago

@waileong-leong sorry, I missed this comment. You can probably use the streaming dataset support Tom added, but my workaround is below. DataReader is the original IterableDataset (a slightly customised version, but I think my code should work with the original one).

import itertools
import logging

from torch.utils.data import BatchSampler

# SetEpochMixin is provided by sentence_transformers.sampler in v3.x.
from sentence_transformers.sampler import SetEpochMixin


class CustomBatchSampler(SetEpochMixin, BatchSampler):
    """
    Replicates the original IterableDataset-based sampling.
    """

    def __init__(self, dataset, batch_size: int, drop_last: bool, *args, **kwargs):
        super().__init__(dataset, batch_size, drop_last)
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        assert self.drop_last, "not implemented"

    def __iter__(self):
        logging.debug("Preparing dataset for sampling...")
        # DataReader is the original (slightly customised) IterableDataset described above;
        # each batch is simply the next batch_size items it yields.
        reader = DataReader(self.dataset)

        logging.debug("Prepared dataset for sampling.")
        batch = 0
        while batch < len(self):
            # Take the next batch_size items from the reader as one batch.
            yield list(itertools.islice(reader, self.batch_size))
            batch += 1

    def __len__(self) -> int:
        if self.drop_last:
            return len(self.dataset) // self.batch_size
        else:
            return (len(self.dataset) + self.batch_size - 1) // self.batch_size