In addition, `model_torch` was introduced as an argument to `ONNXProgram` initialization.
The problem also involves the additional concept of lifted initializers introduced by `torch.export.export`. Both lifted initializers and fake exporting produce an `ONNXProto` that is missing tensor states.
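To make "lifted initializers" concrete, here is a minimal sketch (assuming PyTorch >= 2.1; the module `M` is illustrative):

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)

# torch.export lifts parameters and buffers to graph inputs ("lifted
# initializers"), so the exported graph itself carries no tensor state.
ep = torch.export.export(M(), (torch.randn(1, 3),))
print(ep.graph_signature.inputs_to_parameters)  # lifted inputs -> parameter names
```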
Some findings from exploring the Hugging Face `from_pretrained` API: we do have non-persistent buffers in the `.bin` files, so to extend support to HF users, we could accept, for example, `.bin` files in the `torch_state_dicts` input to get rid of the `adapt_torch_input` hack.
```python
import io
from typing import Tuple, Union

import onnx

def save_model_with_external_data(
    basepath: str,
    model_location: str,
    initializer_location: str,
    torch_state_dicts: Tuple[Union[dict, str, io.BytesIO], ...],
    onnx_model: onnx.ModelProto,  # type: ignore[name-defined]
    rename_initializer: bool = False,
) -> None:
    ...
```
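A hypothetical call based on the signature above; this helper is internal to the exporter, and the argument semantics here are assumed from their names:

```python
import onnx

onnx_model = onnx.load("model_without_weights.onnx")
save_model_with_external_data(
    basepath="/tmp/export",
    model_location="model.onnx",               # model file under basepath
    initializer_location="model.onnx.data",    # external-data file for tensors
    torch_state_dicts=("pytorch_model.bin",),  # HF checkpoint accepted as a path
    onnx_model=onnx_model,
    rename_initializer=True,
)
```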
Yup, that is aligned with my suggestion of renaming `model_state_dict` to `checkpoint` or another generic term that allows for paths or dicts holding the complete state of a model, not only `torch.nn.Module.state_dict`.
However, we should also consider cases where the model is not from HF: how do we error out when missing parameters are not in the checkpoint? Or how do we allow users to specify what is missing? The latter becomes less important if `ExportedProgram` supports non-persistent buffers, as they have been pursuing lately. One possible early check is sketched below.
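A minimal sketch of that kind of validation, assuming a plain `state_dict`-style checkpoint (the helper name and error policy are illustrative):

```python
import torch

def check_checkpoint_covers_model(model: torch.nn.Module, checkpoint_path: str) -> None:
    # Compare the tensor names the model persists against the checkpoint keys.
    # Note: non-persistent buffers never appear in state_dict(), which is
    # exactly the gap discussed above.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    expected = set(model.state_dict().keys())
    missing = expected - set(checkpoint.keys())
    if missing:
        raise ValueError(f"Checkpoint is missing tensors: {sorted(missing)}")
```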
Will revisit later. At first sight, additional model weights can easily be applied to the ONNX model as initializers after export.
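For example, a minimal sketch of patching weights in after export (`extra_weights` is a hypothetical name-to-array mapping):

```python
import numpy as np
import onnx
from onnx import numpy_helper

model = onnx.load("model.onnx")
# Hypothetical mapping of initializer names to weight arrays.
extra_weights = {"fc.weight": np.zeros((10, 10), dtype=np.float32)}
for name, array in extra_weights.items():
    # Append each tensor to the graph as a named initializer.
    model.graph.initializer.append(numpy_helper.from_array(array, name))
onnx.save(model, "model_with_weights.onnx")
```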
The current design of the `torch.onnx.dynamo_export` API, when combined with PyTorch's FakeTensor, produces an `ExportedProgram` which does not have all the metadata needed to execute the model.

When `ONNXProgram` is produced by exporting a non-fake model, it is capable of producing a complete `ONNXProto`; but when `ONNXProgram` is produced by exporting a fake model, the user must provide a real (non-fake) model with the desired state.

Today, in order to generate the same `ONNXProto` from either fake or non-fake models, the `ONNXProgram` APIs require an extra argument, `ONNXProgram.save(..., model_with_state_dict=None)`, during model persistence. `ONNXProgram.__call__(...)` and `ONNXProgram.adapt_torch_inputs_to_onnx(...)` will also require the same argument.

This issue is intended to investigate alternative solutions for that duality of `ONNXProgram`, that is, a way to handle `ONNXProgram` objects whether or not they contain model state, keeping the user experience as an important requirement.
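For illustration, a sketch of the fake-export flow described above; `MyModel` is illustrative, and the keyword name follows the signature quoted in this issue, so it may differ across PyTorch versions:

```python
import torch

class MyModel(torch.nn.Module):  # illustrative module
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.fc(x)

# Export under fake mode: no real weights are ever materialized.
with torch.onnx.enable_fake_mode():
    fake_model = MyModel()
    onnx_program = torch.onnx.dynamo_export(fake_model, torch.randn(1, 3))

# The proto produced from a fake model carries no tensor state, so a real
# model with the desired state must be supplied at persistence time.
real_model = MyModel()
onnx_program.save("model.onnx", model_with_state_dict=real_model)
```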
PRs related to this issue: