kostaleonard / mlops

A framework for conducting MLOps.
MIT License

DataProcessor objects with super() calls fail to pickle #49

Open kostaleonard opened 2 years ago

kostaleonard commented 2 years ago

DataProcessor objects whose methods call super() fail to pickle. The underlying bug is in dill. It was apparently fixed in a recent dill PR, but no release of the package includes the fix yet.

kostaleonard commented 2 years ago

As a workaround, call the parent class's method explicitly and pass self, instead of using super(). This example is from tests/dataset/doubled_preset_data_processor.py.

Original:

class DoubledPresetDataProcessor(PresetDataProcessor):
    """Processes a preset dataset, with no file I/O; doubles tensor values."""

    def get_raw_features_and_labels(self, dataset_path: str) -> \
            Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
        """Returns doubled preset raw feature and label tensors.

        :param dataset_path: Unused
        :return: A 2-tuple of the features dictionary and labels dictionary,
            with matching keys and ordered tensors.
        """
        # This will cause a pickling error.
        features, labels = super().get_raw_features_and_labels(dataset_path)
        for name, tensor in features.items():
            features[name] = 2 * tensor
        return features, labels

Workaround:

class DoubledPresetDataProcessor(PresetDataProcessor):
    """Processes a preset dataset, with no file I/O; doubles tensor values."""

    def get_raw_features_and_labels(self, dataset_path: str) -> \
            Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
        """Returns doubled preset raw feature and label tensors.

        :param dataset_path: Unused
        :return: A 2-tuple of the features dictionary and labels dictionary,
            with matching keys and ordered tensors.
        """
        # See #49 for why we can't use super().
        features, labels = PresetDataProcessor.get_raw_features_and_labels(
            self, dataset_path)
        for name, tensor in features.items():
            features[name] = 2 * tensor
        return features, labels
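
For anyone who wants to try the pattern outside the repo, here is a minimal sketch. The classes are hypothetical stand-ins, not the real mlops processors, and plain lists replace the numpy tensors. It checks that both forms return the same data, and shows one visible difference between them: zero-argument super() makes the method carry a hidden `__class__` closure variable, while the explicit parent call does not. That this closure variable is what trips up dill is an assumption on my part; the thread above does not state the exact cause.

```python
class BaseProcessor:
    """Hypothetical stand-in for PresetDataProcessor."""

    def get_raw_features_and_labels(self, dataset_path):
        # Fixed toy data; dataset_path is unused, as in the preset processor.
        return {"X": [1, 2, 3]}, {"y": [0, 1, 0]}


class DoubledWithSuper(BaseProcessor):
    """Uses zero-argument super(), like the original code."""

    def get_raw_features_and_labels(self, dataset_path):
        features, labels = super().get_raw_features_and_labels(dataset_path)
        doubled = {name: [2 * v for v in vals]
                   for name, vals in features.items()}
        return doubled, labels


class DoubledExplicit(BaseProcessor):
    """Names the parent class explicitly, like the workaround."""

    def get_raw_features_and_labels(self, dataset_path):
        features, labels = BaseProcessor.get_raw_features_and_labels(
            self, dataset_path)
        doubled = {name: [2 * v for v in vals]
                   for name, vals in features.items()}
        return doubled, labels


# Both variants produce the same result.
super_result = DoubledWithSuper().get_raw_features_and_labels("unused")
explicit_result = DoubledExplicit().get_raw_features_and_labels("unused")
print(super_result == explicit_result)

# Zero-argument super() relies on an implicit __class__ closure variable;
# the explicit parent call does not need one.
super_freevars = (
    DoubledWithSuper.get_raw_features_and_labels.__code__.co_freevars)
explicit_freevars = (
    DoubledExplicit.get_raw_features_and_labels.__code__.co_freevars)
print("__class__" in super_freevars)
print("__class__" in explicit_freevars)
```

One caveat with the workaround: the explicit call hard-codes the parent class, so it skips the method resolution order that super() would follow under multiple inheritance. For single-inheritance processors like the one above the behavior is the same.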