coralnet / pyspacer

Python-based tools for spatial image analysis

Training: image-based batch size can lead to very uneven batches #78

Closed. StephenChan closed this 6 months ago

StephenChan commented 7 months ago

Mini-batch size calculation:

    # Calculate max nbr images to keep in memory (for 5000 samples total).
    max_imgs_in_memory = 5000 // labels.samples_per_image

samples_per_image is taken from the first image in the input:

    @property
    def samples_per_image(self):
        return len(next(iter(self.data.values())))
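Since only the first image is consulted, a lopsided input skews the whole calculation. A rough sketch with hypothetical point counts (mirroring the job described below):

    # Sketch: the memory cap is derived from the first image only.
    point_counts = [10] + [1000] * 499   # one small image followed by large ones
    samples_per_image = point_counts[0]  # what the property effectively returns: 10
    max_imgs_in_memory = 5000 // samples_per_image  # 500 images per mini-batch
    # Samples actually loaded for the first mini-batch of 500 images:
    batch_samples = sum(point_counts[:max_imgs_in_memory])  # 499,010 -- not ~5,000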

So I tried a training job where the first image had 10 points, followed by a bunch of images with 1000 points each (unheard of for a single CoralNet source, but possible if multiple sources are incorporated). The logs confirmed that it decided on 500 images per mini-batch. Then it failed with this:

Traceback (most recent call last):
  File "/workspace/spacer/spacer/tasks.py", line 122, in process_job
    results.append(run[job_msg.task_name](task))
  File "/workspace/spacer/spacer/tasks.py", line 45, in train_classifier
    clf, val_results, return_message = trainer(
  File "/workspace/spacer/spacer/train_classifier.py", line 50, in __call__
    clf, ref_accs = train(train_labels, feature_loc, nbr_epochs, clf_type)
  File "/workspace/spacer/spacer/train_utils.py", line 83, in train
    clf.partial_fit(x, y, classes=classes)
  File "/usr/local/lib/python3.10/dist-packages/sklearn/neural_network/_multilayer_perceptron.py", line 1215, in partial_fit
    super().partial_fit(X, y)
  File "/usr/local/lib/python3.10/dist-packages/sklearn/neural_network/_multilayer_perceptron.py", line 790, in partial_fit
    return self._fit(X, y, incremental=True)
  File "/usr/local/lib/python3.10/dist-packages/sklearn/neural_network/_multilayer_perceptron.py", line 394, in _fit
    X, y = self._validate_input(X, y, incremental, reset=first_pass)
  File "/usr/local/lib/python3.10/dist-packages/sklearn/neural_network/_multilayer_perceptron.py", line 1109, in _validate_input
    X, y = self._validate_data(
  File "/usr/local/lib/python3.10/dist-packages/sklearn/base.py", line 596, in _validate_data
    X, y = check_X_y(X, y, **check_params)
  File "/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py", line 1074, in check_X_y
    X = check_array(
  File "/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py", line 856, in check_array
    array = np.asarray(array, order=order, dtype=dtype)
numpy.core._exceptions._ArrayMemoryError: Unable to allocate 4.53 GiB for an array with shape (475192, 1280) and data type float64
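For reference, the requested allocation is consistent with that batch size; the arithmetic below just reproduces the number from the traceback:

    # 475,192 rows (just under 500 images x 1000 points) of 1280-dim float64 features:
    bytes_needed = 475_192 * 1280 * 8   # 4,865,966,080 bytes
    print(bytes_needed / 2**30)         # ~4.53 GiB for a single mini-batch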

This machine had close to 16 GB of RAM, so there were probably multiple similarly large arrays in memory at the time (which might point to some memory optimization to be done in the code). But the main point is that the batch size is supposed to be part of the instance-requirements design. Also, I imagine the training algorithm is meant to work with roughly uniform batch sizes, and having them vary this much could be considered an improper implementation with undefined behavior. So making the batch size consistent (in annotation count) regardless of variable points per image seems worthwhile, as long as it's not a big implementation chore.

I've actually been working on this, issue #59, and issue #60 together as they touch common parts of the code, but I wanted to document this behavior while I was at it.

StephenChan commented 6 months ago

Resolved in PR #71. The batch size no longer depends on points per image; in fact, it no longer respects image 'boundaries' at all, so a single image's points can be split across multiple batches. A generator function made this decently clean to code (IMO).
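The shape of the approach is roughly like the sketch below (hypothetical names and data layout, not the actual PR #71 code):

    def fixed_size_batches(labels_by_image, batch_size=5000):
        """Yield fixed-size batches of point annotations, ignoring image
        boundaries. A rough sketch only; labels_by_image is assumed to
        map each image key to a list of (row, col, label) tuples."""
        x_batch, y_batch = [], []
        for image_key, points in labels_by_image.items():
            for row, col, label in points:
                x_batch.append((image_key, row, col))
                y_batch.append(label)
                if len(x_batch) == batch_size:
                    yield x_batch, y_batch
                    x_batch, y_batch = [], []
        if x_batch:
            yield x_batch, y_batch  # final, possibly smaller batch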

Note that the preprocess_labels() train/ref/val split is still done at the image level, though, not the point level. But I think that's acceptable: it keeps the sets simpler to reason about, the potential unevenness isn't nearly as bad as the original issue, and it's still possible to build your own train/ref/val split that puts a single image's points in multiple sets.