intel-analytics / ipex-llm

Accelerate local LLM inference and finetuning (LLaMA, Mistral, ChatGLM, Qwen, Baichuan, Mixtral, Gemma, Phi, MiniCPM, etc.) on Intel XPU (e.g., local PC with iGPU and NPU, discrete GPU such as Arc, Flex and Max); seamlessly integrate with llama.cpp, Ollama, HuggingFace, LangChain, LlamaIndex, GraphRAG, DeepSpeed, vLLM, FastChat, Axolotl, etc.
Apache License 2.0

Orca: Support customized training in pytorch estimator with Hook Function #6104

Closed: leonardozcm closed this 1 year ago

leonardozcm commented 1 year ago

Motivation

We need to provide users with an overridable interface for implementing custom training strategies at each stage of training. Existing third-party training frameworks (pytorch-lightning, mmcv, detectron2, etc.) commonly use hook functions for this.

Related Work

Here is how pytorch-lightning (pl) and mmcv implement their hook mechanisms:

  1. pytorch-lightning expects users to override the stage-specific hook functions of the base LightningModule; the Trainer later calls them with the corresponding arguments. The Trainer manages a TrainingEpochLoop, an abstraction of the training process composed of a BatchLoop, an OptimizerLoop, and so on, which calls these hooks at the appropriate points in the epoch and batch loops.

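    from typing import Any, Callable, Optional, Union

    from pytorch_lightning import LightningModule
    from pytorch_lightning.core.optimizer import LightningOptimizer
    from pytorch_lightning.utilities.types import EPOCH_OUTPUT
    from torch.optim import Optimizer

    class LitClassifier(LightningModule):
        def forward(self, x):
            ...

        # called for each training batch
        def training_step(self, batch, batch_idx):
            ...

        # called at the end of each training epoch with the outputs of all training_step calls
        def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
            ...

        # override to customize the optimizer step, e.g. when the learning rate depends on epoch or batch_idx
        def optimizer_step(
            self,
            epoch: int,
            batch_idx: int,
            optimizer: Union[Optimizer, LightningOptimizer],
            optimizer_idx: int = 0,
            optimizer_closure: Optional[Callable[[], Any]] = None,
            on_tpu: bool = False,
            using_native_amp: bool = False,
            using_lbfgs: bool = False,
        ) -> None:
            # default behavior: run the closure (forward + backward), then step the optimizer
            optimizer.step(closure=optimizer_closure)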

    For example, this is how OptimizerLoop invokes the optimizer_step hook:

    def _optimizer_step(...):
        ...
        self.trainer._call_lightning_module_hook(
            "optimizer_step",
            self.trainer.current_epoch,
            batch_idx,
            optimizer,
            opt_idx,
            train_step_and_backward_closure,
            on_tpu=isinstance(self.trainer.accelerator, TPUAccelerator),
            using_native_amp=(self.trainer.amp_backend == AMPType.NATIVE),
            using_lbfgs=is_lbfgs,
        )
        ...

    This approach is not very flexible, because each hook receives only a fixed set of parameters, and adopting it in Orca would not fit well with our existing design.

  2. mmcv provides a Hook class that each component uses to define its behavior at different stages. A runner object holds all components required for training (much like our TorchRunner) and expects a Hook object to be registered for each component that has customized behavior. When the runner reaches a stage point, it calls the corresponding method of every registered hook. Note that the entire runner is passed to each hook function.

    class Hook:
        stages = ('before_run', 'before_train_epoch', 'before_train_iter',
                  'after_train_iter', 'after_train_epoch', 'before_val_epoch',
                  'before_val_iter', 'after_val_iter', 'after_val_epoch',
                  'after_run')

        def before_run(self, runner):
            pass

        def before_epoch(self, runner):
            pass

        def before_iter(self, runner):
            pass

        def before_train_epoch(self, runner):
            self.before_epoch(runner)

        def before_val_epoch(self, runner):
            self.before_epoch(runner)
        ...

    Hooks are registered with the runner before training:

    def register_training_hooks(
            self,
            lr_config: Union[Dict, Hook, None],
            optimizer_config: Union[Dict, Hook, None] = None, ...):
        self.register_lr_hook(lr_config)
        self.register_optimizer_hook(optimizer_config)
        ...

    The runner then calls the hooks at fixed stage points:

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            ...

    However, because each hook function receives only the runner, it is hard for an mmcv hook to access intermediate outputs such as the current batch and loss.
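To make the mmcv limitation concrete, here is a minimal, framework-free sketch of a runner that dispatches hooks by stage name. It is a toy stand-in, not mmcv's actual Runner: MiniRunner, RunnerHook and LossLogger are hypothetical names, and only the idea of a call_hook(stage) that passes the runner itself is taken from the description above. Because the hook receives nothing but the runner, the runner has to stash the loss on itself (here as runner.outputs) before calling the hook.

```python
from typing import Any, Callable, List


class RunnerHook:
    """mmcv-style hook (simplified): every stage method receives only the runner."""

    def after_train_iter(self, runner: "MiniRunner") -> None:
        pass


class MiniRunner:
    """Hypothetical toy runner that dispatches registered hooks by stage name."""

    def __init__(self, train_step: Callable[[Any], float]):
        self.train_step = train_step
        self.hooks: List[RunnerHook] = []
        self.iter = 0
        self.outputs: dict = {}

    def register_hook(self, hook: RunnerHook) -> None:
        self.hooks.append(hook)

    def call_hook(self, fn_name: str) -> None:
        # The runner is the only argument a hook ever receives.
        for hook in self.hooks:
            getattr(hook, fn_name)(self)

    def train(self, data_loader) -> None:
        for data_batch in data_loader:
            loss = self.train_step(data_batch)
            # To let hooks see the loss, the runner must store it on itself first.
            self.outputs = {"loss": loss}
            self.call_hook("after_train_iter")
            self.iter += 1


class LossLogger(RunnerHook):
    def __init__(self) -> None:
        self.losses: List[float] = []

    def after_train_iter(self, runner: MiniRunner) -> None:
        # Only what the runner chose to expose is reachable; data_batch is not.
        self.losses.append(runner.outputs["loss"])


runner = MiniRunner(train_step=lambda batch: sum(batch) / len(batch))
logger = LossLogger()
runner.register_hook(logger)
runner.train([[1.0, 3.0], [2.0, 4.0]])
print(logger.losses)  # [2.0, 3.0]
```

In this sketch the hook can see only what the runner explicitly stores; the current batch, for instance, is not reachable from LossLogger.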

Our Design

To keep the existing implementation, I prefer mmcv-style hook functions. Users would implement a hook class for each component, like:

class Hook:

    # each stage method receives keyword arguments, e.g. model, batch_idx, epoch_idx, ...
    def before_run(self, **kwargs):
        pass

    def before_epoch(self, **kwargs):
        pass

    def before_iter(self, **kwargs):
        pass
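As an illustration of a component hook under the proposed design, here is a minimal sketch of a scheduler-style hook that applies linear learning-rate warmup in before_epoch. Assumptions: stage methods receive keyword arguments (matching how call_hooks invokes them with fn(**kwargs)), the optimizer is passed under the hypothetical keyword optimizer and exposes torch-style param_groups (a list of dicts with an "lr" entry), and LinearWarmupHook and FakeOptimizer are illustrative names only.

```python
from typing import Any


class Hook:
    """Base hook in the proposed style: every stage method is a no-op by default."""

    def before_run(self, **kwargs: Any) -> None:
        pass

    def before_epoch(self, **kwargs: Any) -> None:
        pass

    def before_iter(self, **kwargs: Any) -> None:
        pass


class LinearWarmupHook(Hook):
    """Hypothetical scheduler hook: linearly warm up the learning rate over the first epochs."""

    def __init__(self, base_lr: float, warmup_epochs: int):
        self.base_lr = base_lr
        self.warmup_epochs = warmup_epochs

    def before_epoch(self, optimizer: Any = None, epoch: int = 0, **kwargs: Any) -> None:
        if optimizer is None:
            return  # nothing to adjust if no optimizer was passed in
        # Scale grows from 1/warmup_epochs up to 1.0, then stays at 1.0.
        scale = min(1.0, (epoch + 1) / self.warmup_epochs)
        for group in optimizer.param_groups:
            group["lr"] = self.base_lr * scale


class FakeOptimizer:
    """Stand-in exposing torch-style param_groups, so the sketch runs without torch."""

    def __init__(self, lr: float):
        self.param_groups = [{"lr": lr}]


opt = FakeOptimizer(lr=0.1)
hook = LinearWarmupHook(base_lr=0.1, warmup_epochs=4)
hook.before_epoch(optimizer=opt, epoch=0, batch_idx=0)  # extra kwargs are absorbed
print(opt.param_groups[0]["lr"])  # 0.025
```

Such an object would be passed as, for example, the "scheduler_hook" entry of the hooks dict.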

Users pass the hooks to the PySpark and Ray estimators:

est = Estimator.from_torch(model=model,
                           optimizer=optimizer,
                           # ... other arguments elided
                           hooks={"model_hook": hook1,
                                  "optimizer_hook": hook2,
                                  "scheduler_hook": hook3})

TrainingOperator then calls them during training:

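class TrainingOperator:
    def _train_loop(self, iterator, info, _progress_bar, metric_meters, callbacks):
        for batch_idx, batch in enumerate(iterator):
            batch_info = {
                "batch_idx": batch_idx,
                "global_step": self.global_step
            }
            batch_info.update(info)
            # call the before_iter method of every registered hook, passing the batch
            # plus batch_idx, global_step and the entries of info as keyword arguments
            self.call_hooks("before_iter", batch=batch, **batch_info)
            metrics = self.train_batch(batch, batch_info=batch_info)

    def call_hooks(self, func_name, **kwargs):
        # self.hooks holds the dict passed as hooks=... to Estimator.from_torch,
        # so iterate over its values (the hook objects), not its keys
        for hook in self.hooks.values():
            fn = getattr(hook, func_name)
            fn(**kwargs)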
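Putting the pieces together, the following is a self-contained sketch of the proposed dispatch loop, written without torch so it runs anywhere. ToyTrainingOperator, RecordLossHook and the after_iter stage are hypothetical stand-ins, not Orca's actual TrainingOperator; the point is that hooks receive the epoch, batch and loss directly as keyword arguments instead of digging them out of a runner.

```python
from typing import Any, Callable, Dict, Iterable, List


class Hook:
    """Proposed-style base hook: stage methods take keyword arguments and do nothing by default."""

    def before_epoch(self, **kwargs: Any) -> None:
        pass

    def before_iter(self, **kwargs: Any) -> None:
        pass

    def after_iter(self, **kwargs: Any) -> None:
        pass


class ToyTrainingOperator:
    """Hypothetical, framework-free stand-in for a training operator's loop."""

    def __init__(self, train_batch: Callable[[Any], float], hooks: Dict[str, Hook]):
        self.train_batch = train_batch
        self.hooks = hooks  # same shape as the proposed hooks={...} argument
        self.global_step = 0

    def call_hooks(self, func_name: str, **kwargs: Any) -> None:
        # Iterate over the hook objects (dict values), not the dict keys.
        for hook in self.hooks.values():
            getattr(hook, func_name)(**kwargs)

    def train_epoch(self, iterator: Iterable[Any], epoch: int) -> None:
        self.call_hooks("before_epoch", epoch=epoch)
        for batch_idx, batch in enumerate(iterator):
            self.call_hooks("before_iter", epoch=epoch, batch_idx=batch_idx, batch=batch)
            loss = self.train_batch(batch)
            # Unlike a runner-only signature, hooks receive batch and loss directly.
            self.call_hooks("after_iter", epoch=epoch, batch_idx=batch_idx,
                            batch=batch, loss=loss, global_step=self.global_step)
            self.global_step += 1


class RecordLossHook(Hook):
    def __init__(self) -> None:
        self.records: List[tuple] = []

    def after_iter(self, epoch: int = 0, batch_idx: int = 0, loss: float = 0.0,
                   **kwargs: Any) -> None:
        # Declare only the arguments this hook needs; the rest land in **kwargs.
        self.records.append((epoch, batch_idx, loss))


recorder = RecordLossHook()
op = ToyTrainingOperator(train_batch=lambda b: float(sum(b)), hooks={"model_hook": recorder})
op.train_epoch([[1, 2], [3, 4]], epoch=0)
print(recorder.records)  # [(0, 0, 3.0), (0, 1, 7.0)]
```

Passing everything as keyword arguments lets each hook declare only the values it cares about and absorb the rest with **kwargs, so new arguments can be added to a stage later without breaking existing hooks.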

Any suggestions? @jason-dai @hkvision

leonardozcm commented 1 year ago

We enable mmcv_estimator first.