intel-analytics / ipex-llm

Accelerate local LLM inference and finetuning (LLaMA, Mistral, ChatGLM, Qwen, Baichuan, Mixtral, Gemma, Phi, MiniCPM, etc.) on Intel XPU (e.g., local PC with iGPU and NPU, discrete GPU such as Arc, Flex and Max); seamlessly integrate with llama.cpp, Ollama, HuggingFace, LangChain, LlamaIndex, GraphRAG, DeepSpeed, vLLM, FastChat, Axolotl, etc.
Apache License 2.0

Nano: basic pytorch-lightning "wrapper" design #3272

Closed TheaperDeng closed 2 years ago

TheaperDeng commented 3 years ago

@yangw1234

related to #3171

We have decided to provide a pytorch-lightning "wrapper" inside Nano rather than Chronos, to help users transform their pytorch nn.Module into a LightningModule that can be accelerated by bigdl-nano.

To that end, @zhentaocc proposed a decorator @basic_lightningmodule that transforms a user's pytorch nn.Module (which is required to follow a standard training loop) into such a LightningModule. We may add further decorators for GAN and other training loops in the future.

API Design

from bigdl.nano.pytorch import basic_lightningmodule

@basic_lightningmodule(loss=loss, optim=optim, **optim_configs)  # the only thing you need to do
class Net(nn.Module):
    ...

# This will transform Net(nn.Module) into a DEFAULT_PL_MODULE_NAME(LightningModule)

Additionally, if you need ONNX Runtime support, apply it in the same consistent way (see #3174):

from bigdl.nano.pytorch.onnx import onnxruntime_support

@onnxruntime_support
@basic_lightningmodule(loss=loss, optim=optim, **optim_configs)  # the only thing you need to do
class Net(nn.Module):
    ...

# This will transform Net(nn.Module) into a Net(LightningModule) with ONNX Runtime support
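
For completeness, a hedged sketch of how the decorated module might then be trained; the use of the stock pytorch_lightning.Trainer and the train_dataloader name are assumptions for illustration, not part of the proposal:

import pytorch_lightning as pl

# Net is now a LightningModule thanks to the decorator above
model = Net()
trainer = pl.Trainer(max_epochs=3)
trainer.fit(model, train_dataloader)  # train_dataloader: a user-provided torch DataLoader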

Possible implementation

# bigdl/nano/
import pytorch_lightning as pl

def basic_lightningmodule(loss, optim, **optim_configs):
    '''
    :param loss: a loss class (e.g. a subclass of `nn.modules.loss._Loss`) or an instance of one
    :param optim: a `torch.optim.Optimizer` subclass
    :param **optim_configs: additional settings passed to `optim` (e.g. lr)
    '''
    def parametrized(cls):
        class DEFAULT_PL_MODULE_NAME(pl.LightningModule):
            def __init__(self, *args, **kwargs):
                super().__init__()
                self.model = cls(*args, **kwargs)
                self.loss = loss() if isinstance(loss, type) else loss

            def training_step(self, batch, batch_idx):
                # assumes (x, y) batches from a standard supervised training loop
                x, y = batch
                return self.loss(self.model(x), y)

            def configure_optimizers(self):
                # the optimizer can only be created once the model exists
                return optim(self.model.parameters(), **optim_configs)

        return DEFAULT_PL_MODULE_NAME
    return parametrized

We need **optim_configs because we have to set important hyperparameters such as lr (although the defaults are often good enough), and an optimizer cannot be built before the model is built.
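
For example, a minimal sketch (torch.optim.Adam, nn.CrossEntropyLoss, and lr=1e-3 are illustrative choices, not part of the proposal):

import torch
from torch import nn
from bigdl.nano.pytorch import basic_lightningmodule

# lr is forwarded to the optimizer through **optim_configs; the Adam instance
# itself is only created inside configure_optimizers(), after self.model exists
@basic_lightningmodule(loss=nn.CrossEntropyLoss, optim=torch.optim.Adam, lr=1e-3)
class Net(nn.Module):
    ...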

yangw1234 commented 3 years ago

LGTM. One comment:

@basic_lightningmodule(loss_creator=loss, optim_creator=optim, config)

should we make them all keyword arguments? I think positional arguments should be placed before keyword arguments.

TheaperDeng commented 3 years ago

LGTM. One comment:

@basic_lightningmodule(loss_creator=loss, optim_creator=optim, config)

should we make them all keyword arguments? I think positional arguments should be placed before keyword arguments.

Sure, it's actually a typo.

And there seem to be many typos :(

zhentaocc commented 3 years ago

This decorator is implemented in #3181. @yangw1234 Will add a simple unit test soon.