This PR adds observability (logging, profiling and output) features and checkpointing support. It also allows for Modules with multi-input forward calls (feature parity with skorch).
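As a rough illustration of what multi-input support means here: a module whose forward takes several named inputs can be fed a dict keyed by those input names, mirroring plain skorch's dict-of-arrays handling. The class and key names below are made up for the example.

```python
import numpy as np
import torch
from torch import nn

class TwoInputModule(nn.Module):
    """Example module whose forward takes two named inputs."""

    def __init__(self, num_main=10, num_extra=4):
        super().__init__()
        self.dense = nn.Linear(num_main + num_extra, 1)

    def forward(self, X_main, X_extra):
        # Each dict key arrives as its own tensor argument.
        return self.dense(torch.cat([X_main, X_extra], dim=-1))

# skorch routes a dict of arrays to forward() by key; the multi-input
# support added here mirrors that behavior for the Train-backed estimator.
X = {
    "X_main": np.random.rand(32, 10).astype("float32"),
    "X_extra": np.random.rand(32, 4).astype("float32"),
}
y = np.random.rand(32, 1).astype("float32")
```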
Most of the functionality is implemented through callbacks, both for skorch (running on each worker) and for Ray Train. A common pattern is that skorch callbacks persist the necessary data (either on disk or in skorch history objects), which is then reported to Train.
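A minimal sketch of that pattern, assuming a skorch-style on_epoch_end hook; the callback name is made up and the exact Ray Train reporting call depends on the Ray version:

```python
from skorch.callbacks import Callback

class HistoryToTrainReporter(Callback):
    """Sketch: persist values into the skorch history on each worker,
    then forward them to Ray Train at a notify point."""

    def on_epoch_end(self, net, **kwargs):
        # skorch has already recorded this epoch's metrics in net.history.
        last_epoch = net.history[-1]
        metrics = {k: v for k, v in last_epoch.items()
                   if isinstance(v, (int, float))}

        # Hand the persisted values to Ray Train. The exact reporting API
        # (e.g. ray.train.report) varies across Ray versions; shown here
        # only to illustrate the flow.
        from ray import train
        train.report(**metrics)
```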
The skorch callback API is extended with additional notify points. However, normal skorch callbacks will also work.
Outputting progress information has been moved from skorch to Train callbacks. The DetailedHistoryPrintCallback, enabled if profile=True, outputs raw data from every worker in addition to aggregated metrics, while TableHistoryPrintCallback (enabled otherwise) provides human-friendly output identical to the default skorch output. Both callbacks can be configured to hide or display different keys.
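For reference, the default skorch table output that TableHistoryPrintCallback mirrors comes from skorch's PrintLog, which hides columns via keys_ignored; the Train-side callbacks are described as configurable in a similar way (an assumption, since their exact arguments are not spelled out here).

```python
from skorch import NeuralNetRegressor
from torch import nn

# Plain-skorch equivalent of hiding keys from the printed table:
# configure the default PrintLog callback to skip the "dur" column.
net = NeuralNetRegressor(
    nn.Linear(10, 1),
    max_epochs=2,
    callbacks__print_log__keys_ignored=["dur"],
)
```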
Checkpointing is facilitated through the TrainCheckpoint skorch callback (based on skorch's checkpointing logic). It can both save and load checkpoints, and a checkpoint object can be passed as an argument to fit.
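For context, the plain-skorch checkpointing logic it builds on looks like this; how TrainCheckpoint itself is wired into fit is specific to this PR and not shown here.

```python
from skorch.callbacks import Checkpoint

# skorch's Checkpoint callback saves params/optimizer/history when the
# monitored metric improves; a saved checkpoint can be restored later.
cp = Checkpoint(monitor="valid_loss_best", dirname="checkpoints")

# net = NeuralNet(..., callbacks=[cp]); net.fit(X, y)
# net.load_params(checkpoint=cp)  # restore the best saved state
```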
PyTorch profiling is done through the PytorchProfilerLogger skorch callback and the TBXProfilerCallback Train callback. The former runs the PyTorch profiler on every worker and reports its output to Train; the latter simply saves that output on the head node after training completes.
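A minimal sketch of what the per-worker profiling boils down to, using the standard torch.profiler API: traces are produced step by step and exported in a TensorBoard-compatible format, which can then be collected centrally.

```python
import torch
from torch.profiler import (
    ProfilerActivity,
    profile,
    schedule,
    tensorboard_trace_handler,
)

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=2),
    on_trace_ready=tensorboard_trace_handler("./profiler_traces"),
) as prof:
    for step in range(8):
        # Stand-in for a training step on a worker.
        torch.matmul(torch.rand(256, 256), torch.rand(256, 256))
        prof.step()  # advance the profiler schedule each step
```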
Additional information not related to this PR
The main part of Train-Sklearn is data handling, which is done through the RayDataset class, inheriting from skorch's Dataset. The workflow is as follows (a rough sketch of the flow is included after the list):
1. The user passes a ray.data.Dataset, a ray.data.DatasetPipeline, a numpy array, a pandas DataFrame, or a list/dict of the previous two (for multi-input).
2. The RayDataset class converts the input to ray.data if necessary and stores any required metadata.
3. The FixedSplit class takes the ray.data dataset and splits it into train and validation ray.data.DatasetPipelines.
4. Those ray.data.DatasetPipelines are then wrapped in RayPipelineDataset for skorch compatibility.
5. The converted ray.data.DatasetPipelines from RayPipelineDataset are then passed to trainer.train() with metadata in the config.
6. The RayPipelineDatasets are recreated on every worker, and training begins.
7. During training, the PipelineIterator takes in the RayPipelineDatasets and outputs batches as torch.Tensors, which are then used internally by skorch.
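A rough sketch of this flow using public ray.data APIs from the Ray versions that expose DatasetPipeline (window / iter_batches); the RayDataset, RayPipelineDataset and PipelineIterator classes wrap logic along these lines, and the column names below are made up:

```python
import numpy as np
import pandas as pd
import ray
import torch

df = pd.DataFrame({
    "a": np.random.rand(100),
    "b": np.random.rand(100),
    "label": np.random.rand(100),
})

# Step 2: convert the input to ray.data.
ds = ray.data.from_pandas(df)

# Steps 3-4: produce a DatasetPipeline for streaming over epochs
# (the actual train/validation split is handled by FixedSplit).
pipe = ds.window(blocks_per_window=2)

# Step 7: batches are turned into torch.Tensors consumed by skorch.
for batch in pipe.iter_batches(batch_size=32, batch_format="pandas"):
    X = torch.as_tensor(batch[["a", "b"]].to_numpy(), dtype=torch.float32)
    y = torch.as_tensor(batch["label"].to_numpy(), dtype=torch.float32)
```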