pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Revisit ONNXProgram API due to fake tensor support and additional model_with_state_dict addition #115461

Closed thiagocrepaldi closed 2 months ago

thiagocrepaldi commented 11 months ago

The current design of the torch.onnx.dynamo_export API, when combined with PyTorch's FakeTensor, produces an ExportedProgram which does not have all the metadata needed to execute the model.

When an ONNXProgram is produced by exporting a non-fake model, it is capable of producing a complete ONNXProto; but when it is produced by exporting a fake model, the user must provide a real (non-fake) model with the desired state.

Today, in order to generate the same ONNXProto from either a fake or a non-fake model, the ONNXProgram APIs require an extra argument, as sketched below.
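
For illustration, a minimal sketch of the two paths, assuming the PyTorch 2.x dynamo export API; the keyword accepted by ONNXProgram.save has gone by different names across versions (model_state_dict, model_state, the model_with_state_dict of the title), so the one used below is illustrative rather than authoritative:

import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

# Non-fake export: the resulting ONNXProgram already carries real weights,
# so it can serialize a complete ONNXProto on its own.
real_model = TinyModel()
onnx_program = torch.onnx.dynamo_export(real_model, torch.randn(1, 4))
onnx_program.save("model.onnx")

# Fake export: weights never materialize, so serialization needs the real
# state passed in explicitly (the extra argument discussed in this issue).
with torch.onnx.enable_fake_mode() as fake_context:
    fake_model = TinyModel()
    fake_input = torch.randn(1, 4)
export_options = torch.onnx.ExportOptions(fake_context=fake_context)
onnx_program = torch.onnx.dynamo_export(
    fake_model, fake_input, export_options=export_options
)
onnx_program.save("fake_model.onnx", model_state=real_model.state_dict())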

This issue is intended to investigate alternative solutions to that duality of ONNXProgram, that is, a way to handle ONNXProgram objects whether or not they contain model state, with user experience as an important requirement.

PRs related to this issue:

BowenBao commented 11 months ago

In addition, model_torch was introduced as an argument to ONNXProgram initialization.

The problem also involves the additional concept of lifted initializers introduced by torch.export.export. Both lifted initializers and fake export produce an ONNXProto that is missing tensor states.
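
For reference, a minimal sketch of what "lifted" means in torch.export: parameters become placeholder inputs of the graph, and their values live only in the accompanying state dict:

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(2, 2))

    def forward(self, x):
        return x @ self.weight

ep = torch.export.export(M(), (torch.randn(2, 2),))
# The parameter is lifted into a graph input instead of being baked into the
# graph as a constant; the mapping and the tensors are stored alongside it.
print(ep.graph_signature.inputs_to_parameters)  # placeholder name -> "weight"
print(list(ep.state_dict.keys()))               # actual tensor values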

titaiwangms commented 9 months ago

Some findings from exploring the Hugging Face from_pretrained API: the .bin files do contain non-persistent buffers, so to extend support to HF users we could accept, for example, .bin files in the torch_state_dicts input and get rid of the adapt_torch_input hack.

import io
from typing import Tuple, Union

import onnx


def save_model_with_external_data(
    basepath: str,
    model_location: str,
    initializer_location: str,
    torch_state_dicts: Tuple[Union[dict, str, io.BytesIO], ...],
    onnx_model: onnx.ModelProto,
    rename_initializer: bool = False,
) -> None:
    ...
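
A hypothetical call illustrating that suggestion, with a Hugging Face .bin checkpoint passed directly in torch_state_dicts; all paths and filenames here are made up for the example:

import onnx

# Assumes an ONNX model exported without real weights (e.g. via fake export)
# and a HF checkpoint shard on disk.
proto = onnx.load("exported_without_weights.onnx")
save_model_with_external_data(
    basepath="/tmp/export",
    model_location="model.onnx",
    initializer_location="initializers",
    torch_state_dicts=("pytorch_model.bin",),  # accept .bin directly
    onnx_model=proto,
    rename_initializer=True,
)
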
thiagocrepaldi commented 9 months ago

> Some findings from exploring the Hugging Face from_pretrained API: the .bin files do contain non-persistent buffers, so to extend support to HF users we could accept, for example, .bin files in the torch_state_dicts input and get rid of the adapt_torch_input hack.

Yup, that is aligned with my suggestion of renaming model_state_dict to checkpoint or another generic term that allows paths or dicts carrying the complete state of a model, not only torch.nn.Module.state_dict.

However, we should also consider cases where the model is not from HF: how do we error out when missing parameters are not in the checkpoint? Or how do we allow users to specify what is missing? The latter becomes less important if ExportedProgram supports non-persistent buffers, as they have been trying lately.
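
A sketch of what a generic checkpoint argument could accept, and how missing parameters could be surfaced explicitly; the function names here are illustrative, not an existing torch.onnx API:

import io
from typing import Iterable, Union

import torch

def resolve_checkpoint(checkpoint: Union[dict, str, io.BytesIO]) -> dict:
    # Accept an in-memory state dict, a file path, or a file-like object;
    # torch.load handles the latter two uniformly.
    if isinstance(checkpoint, dict):
        return checkpoint
    return torch.load(checkpoint, map_location="cpu")

def check_missing(required_names: Iterable[str], state: dict) -> None:
    # Error out loudly when the checkpoint lacks parameters the graph needs.
    missing = set(required_names) - set(state)
    if missing:
        raise ValueError(f"Checkpoint is missing parameters: {sorted(missing)}")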

justinchuby commented 2 months ago

Will revisit later. At first sight, additional model weights can easily be applied to the ONNX model as initializers after export.
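
A sketch of that post-export approach, grafting real weights onto a weight-less ONNX model; it assumes the lifted graph inputs keep their state-dict names, which may not hold for every exporter version:

import onnx
from onnx import numpy_helper
import torch

proto = onnx.load("exported_without_weights.onnx")
state = torch.load("checkpoint.bin", map_location="cpu")

input_names = {i.name for i in proto.graph.input}
for name, tensor in state.items():
    if name in input_names:
        # Promote the matching graph input to an initializer with real data.
        proto.graph.initializer.append(
            numpy_helper.from_array(tensor.cpu().numpy(), name=name)
        )

onnx.save(proto, "model_with_weights.onnx")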