Yard1 / ray-skorch

Distributed skorch on Ray Train
Apache License 2.0

Observability and checkpointing #9

Closed Yard1 closed 2 years ago

Yard1 commented 2 years ago

This PR adds observability (logging, profiling and output) features and checkpointing support. It also allows for Modules with multi-input forward calls (feature parity with skorch).

Most of the functionality is implemented through callbacks - both for skorch (running on each worker) and for Ray Train. A common pattern is that skorch callbacks persist the necessary data (either on disk or in skorch history objects), which is then reported to Train.
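The persist-then-report pattern can be sketched in plain Python. This is a minimal illustration, not the actual ray-skorch API: `History`, `ReportingCallback` and `report_fn` are hypothetical stand-ins (in real code the report function would be something like Ray Train's reporting call, and the history would be skorch's `net.history`).

```python
class History(list):
    """Minimal stand-in for skorch's per-epoch history object."""

    def new_epoch(self):
        # Each epoch gets its own row (a dict of metrics).
        self.append({})

    def record(self, key, value):
        # Persist a metric into the current epoch's row.
        self[-1][key] = value


class ReportingCallback:
    """Hypothetical worker-side callback: persist metrics into history,
    then forward the latest row to the Train driver."""

    def __init__(self, report_fn):
        # In real code this would be the Ray Train reporting function.
        self.report_fn = report_fn

    def on_epoch_end(self, history, train_loss):
        history.record("train_loss", train_loss)
        # Report only the freshest epoch row to the driver.
        self.report_fn(history[-1])


reported = []  # stands in for what the Train driver receives
history = History()
cb = ReportingCallback(reported.append)
for loss in [0.9, 0.5, 0.3]:
    history.new_epoch()
    cb.on_epoch_end(history, loss)
# reported now holds one row per epoch, mirroring the history contents.
```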

The skorch callback API is extended with additional notify points. However, normal skorch callbacks will also work.

Outputting progress information has been moved from skorch to Train callbacks. DetailedHistoryPrintCallback, enabled when profile=True, outputs raw data from every worker in addition to aggregated metrics, while TableHistoryPrintCallback (enabled otherwise) provides human-friendly output identical to the default skorch output. Both callbacks can be configured to hide or display specific keys.
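The difference between the two print callbacks comes down to raw per-worker rows versus one aggregated row per epoch. A rough sketch of the aggregation step (the function and data shapes here are illustrative assumptions, not the actual implementation):

```python
def aggregate_epoch(worker_rows):
    """Average numeric metrics across all workers for a single epoch.

    worker_rows: one metrics dict per worker, all with the same keys.
    """
    keys = worker_rows[0].keys()
    return {k: sum(row[k] for row in worker_rows) / len(worker_rows) for k in keys}


# Raw rows from two workers for one epoch - this is what a detailed,
# per-worker printout would show.
worker_rows = [
    {"train_loss": 0.40, "valid_loss": 0.50},  # worker 0
    {"train_loss": 0.44, "valid_loss": 0.46},  # worker 1
]

# A table-style printout would show only the aggregated row.
aggregated = aggregate_epoch(worker_rows)
```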

Checkpointing is facilitated through the TrainCheckpoint skorch callback (based on skorch's checkpointing logic). It can both save and load checkpoints, and a checkpoint object can be passed as an argument to fit.
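The save/restore flow can be sketched as follows. This is a hedged, self-contained sketch: the `Checkpoint` class and `fit` function below are hypothetical stand-ins for the TrainCheckpoint callback and the real fit call, used only to show how passing a checkpoint back into fit resumes training.

```python
import pickle


class Checkpoint:
    """Hypothetical in-memory checkpoint; the real callback persists to storage."""

    def __init__(self):
        self.blob = None

    def save(self, state):
        self.blob = pickle.dumps(state)

    def load(self):
        return pickle.loads(self.blob)


def fit(n_epochs, checkpoint=None):
    """Toy training loop: resumes from a checkpoint if one is passed in."""
    if checkpoint is not None and checkpoint.blob is not None:
        state = checkpoint.load()  # resume where the previous run stopped
    else:
        state = {"epoch": 0, "weights": [0.0]}
    ckpt = checkpoint or Checkpoint()
    for _ in range(n_epochs):
        state["epoch"] += 1
        state["weights"] = [w + 0.1 for w in state["weights"]]
        ckpt.save(state)  # snapshot after every epoch, like a checkpoint callback
    return state, ckpt


state, ckpt = fit(2)                    # train for 2 epochs, checkpointing as we go
resumed, _ = fit(1, checkpoint=ckpt)    # resume from epoch 2, end at epoch 3
```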

PyTorch profiling is done through the PytorchProfilerLogger skorch callback and the TBXProfilerCallback Train callback. The former runs the PyTorch profiler on every worker and reports its output to Train; the latter simply saves that output on the head node after training completes.
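The split of responsibilities - workers produce trace data, the driver persists it once training is done - can be sketched in plain Python. Everything here is illustrative: the real callbacks wrap the PyTorch profiler and TensorBoard trace files, not the toy strings used below.

```python
import os
import tempfile


def worker_loop(worker_rank, n_steps):
    """Stand-in for a worker's training loop with profiling enabled:
    each step produces some trace data that is reported back."""
    return [f"worker{worker_rank}-step{step}-trace" for step in range(n_steps)]


def save_traces_on_driver(all_traces, out_dir):
    """Stand-in for the driver-side callback: write the collected traces
    to disk only after training has completed."""
    paths = []
    for i, trace in enumerate(all_traces):
        path = os.path.join(out_dir, f"trace_{i}.txt")
        with open(path, "w") as f:
            f.write(trace)
        paths.append(path)
    return paths


# Two workers, two profiled steps each; traces are gathered on the driver.
all_traces = [t for rank in range(2) for t in worker_loop(rank, 2)]
out_dir = tempfile.mkdtemp()
paths = save_traces_on_driver(all_traces, out_dir)
```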

Additional information not related to this PR

The main part of Train-Sklearn is the data handling. This is done through the RayDataset class, inheriting from skorch's Dataset. The workflow is as follows: