Accelerate local LLM inference and finetuning (LLaMA, Mistral, ChatGLM, Qwen, Baichuan, Mixtral, Gemma, Phi, MiniCPM, etc.) on Intel XPU (e.g., local PC with iGPU and NPU, discrete GPU such as Arc, Flex and Max); seamlessly integrate with llama.cpp, Ollama, HuggingFace, LangChain, LlamaIndex, GraphRAG, DeepSpeed, vLLM, FastChat, Axolotl, etc.
Apache License 2.0
Orca: Support customized training in pytorch estimator with Hook Function #6104
Motivation
We need to provide users with an overridable interface for implementing custom training strategies at each stage of training. In current designs, hook functions are commonly used by existing third-party training frameworks (pytorch-lightning, mmcv, detectron2, etc.) to achieve this.
Related Work
Here is how pytorch-lightning (pl) and mmcv implement their hook mechanisms:
pytorch-lightning expects users to override the stage-specific hook functions of the base LightningModule; these are called later by the Trainer with the corresponding arguments. The Trainer manages a TrainingEpochLoop, an abstraction of the training process consisting of a BatchLoop, an OptimizerLoop, and so on, which calls these hooks at the right moments within the epoch and batch loops.
```python
from typing import Any, Callable, Optional, Union

from pytorch_lightning import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
from torch.optim import Optimizer


class LitClassifier(LightningModule):
    def forward(self, x):
        ...

    # called for every training batch
    def training_step(self, batch, batch_idx):
        ...

    # called at the end of the training epoch, with the collected
    # outputs of all training steps
    def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        ...

    # customized step function, e.g. when the lr depends on the
    # current epoch or batch_idx
    def optimizer_step(
        self,
        epoch: int,
        batch_idx: int,
        optimizer: Union[Optimizer, LightningOptimizer],
        optimizer_idx: int = 0,
        optimizer_closure: Optional[Callable[[], Any]] = None,
        on_tpu: bool = False,
        using_native_amp: bool = False,
        using_lbfgs: bool = False,
    ) -> None:
        ...
```
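For example, optimizer steps are dispatched from OptimizerLoop, roughly like this (a simplified sketch, not the exact pytorch-lightning source):

```python
# Simplified sketch of how OptimizerLoop dispatches to the user-overridable
# optimizer_step hook; the real pytorch-lightning code differs across versions.
class OptimizerLoop:
    def _optimizer_step(self, optimizer, opt_idx, batch_idx, closure):
        # the user's hook is invoked with a fixed argument list; anything
        # outside this signature is out of the user's reach
        self.trainer.lightning_module.optimizer_step(
            self.trainer.current_epoch,
            batch_idx,
            optimizer,
            opt_idx,
            optimizer_closure=closure,
            on_tpu=False,            # resolved from the accelerator in the real loop
            using_native_amp=False,  # resolved from the AMP backend in the real loop
            using_lbfgs=False,       # resolved from the optimizer type in the real loop
        )
```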
This implementation is not very flexible because the hook functions only receive a limited set of parameters, and adopting it in Orca would not fit well with our existing design.
mmcv instead provides a Hook base class through which each component defines its behavior at different stages. A runner object contains all the components required in training (much like our TorchRunner), and Hook objects are registered on it for components with customized behaviors. Once execution reaches each stage point, the registered hooks are called. Note that the entire runner is passed to each hook function.
Hooks should be registered on the runner before fitting:
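(A minimal sketch; the custom hook below is illustrative, while Hook, register_hook and EpochBasedRunner are mmcv's real API.)

```python
from mmcv.runner import EpochBasedRunner, Hook

# illustrative custom hook: log the loss after every training iteration
class LossLoggingHook(Hook):
    def after_train_iter(self, runner):
        # the whole runner is passed in, so every component is reachable
        runner.logger.info("loss: %s", runner.outputs["loss"])

runner = EpochBasedRunner(model, optimizer=optimizer,
                          work_dir="./work_dir", logger=logger)
runner.register_hook(LossLoggingHook())  # register before fitting
runner.run(data_loaders, workflow=[("train", 1)], max_epochs=10)
```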
And the hooks will be called at fixed stage points:
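This is essentially what mmcv's base runner does (simplified):

```python
# simplified from mmcv's BaseRunner: every registered hook is invoked
# with the runner itself as the only argument
def call_hook(self, fn_name):
    for hook in self._hooks:
        getattr(hook, fn_name)(self)

# the training loop triggers these at fixed stage points, e.g.
#   self.call_hook("before_train_epoch")
#   self.call_hook("before_train_iter")
#   self.call_hook("after_train_iter")
#   self.call_hook("after_train_epoch")
```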
But with such a hook function signature (receiving only the runner), it is hard in mmcv to get at intermediate outputs such as the current batch or the loss.
Our Design
In order to keep the existing implementation, I prefer the mmcv-style hook function. We suggest that users implement a component hook class like the sketch below:
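(A sketch only; the stage names mirror the mmcv convention and are not a committed API.)

```python
# illustrative user-facing hook class for Orca
class CustomTrainingHook:
    def before_train_epoch(self, runner):
        pass

    def before_train_iter(self, runner):
        pass

    def after_train_iter(self, runner):
        # the whole runner is passed in, so the hook can reach the model,
        # optimizer, current batch, loss, etc.
        pass

    def after_train_epoch(self, runner):
        pass
```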
In the pyspark and ray estimators:
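(Hypothetical wiring; the hooks parameter below is illustrative and not a committed API.)

```python
from bigdl.orca.learn.pytorch import Estimator

# hypothetical "hooks" argument, forwarded to every remote runner
estimator = Estimator.from_torch(
    model=model_creator,
    optimizer=optimizer_creator,
    loss=loss_creator,
    backend="ray",  # or the pyspark backend
    hooks=[CustomTrainingHook()],
)
estimator.fit(data=train_data_creator, epochs=10)
```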
And we call them in the TrainOperator:
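Roughly as follows (a sketch, not the final implementation):

```python
# sketch of the call sites inside the training operator
class TrainOperator:
    def __init__(self, hooks=None):
        self.hooks = hooks or []

    def call_hook(self, fn_name):
        # mirror mmcv: pass the whole operator so hooks can reach any component
        for hook in self.hooks:
            getattr(hook, fn_name)(self)

    def train_epoch(self, iterator, info):
        self.call_hook("before_train_epoch")
        for batch_idx, batch in enumerate(iterator):
            self.call_hook("before_train_iter")
            self.train_batch(batch, batch_info={"batch_idx": batch_idx})
            self.call_hook("after_train_iter")
        self.call_hook("after_train_epoch")
```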
Any suggestions? @jason-dai @hkvision