Open rbavery opened 5 days ago
Although torch.compile is not recommended, I wouldn't remove it completely, since some existing artifacts might exist and still depend on it (e.g.: older code that doesn't work with newer methods). Instead, I think we should add a (*) + "WARNING" block under the https://github.com/stac-extensions/mlm?tab=readme-ov-file#artifact-type-enum to "STRONGLY" advise migrating to other more appropriate artifact types. Basically, removing it might cause more problems by obfuscation rather than explicitly mentioning why it is bad practice.
As for the other proposed types, I agree. We need more non-pytorch examples.
The examples don't have to be an exhaustive list. The mlm:artifact_type field is not even explicitly defined in the JSON schema at the moment. We should define it as a lax string, similar to other definitions:
https://github.com/stac-extensions/mlm/blob/25bef80362629188c78938277643a2c53659163b/json-schema/schema.json#L325-L366
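As a rough sketch of what a "lax string" definition could look like, the snippet below writes a hypothetical schema fragment as a Python dict and checks values against it by hand. The shape of the fragment (a non-empty string with non-binding examples) and the helper name are assumptions for illustration, not the definition used in the linked schema.

```python
# Hypothetical JSON Schema fragment for mlm:artifact_type: any non-empty
# string is accepted; the "examples" are hints only and are not enforced.
ARTIFACT_TYPE_SCHEMA = {
    "type": "string",
    "minLength": 1,
    "examples": ["torch.save", "torch.compile", "torch.export"],
}


def is_valid_artifact_type(value):
    """Hand-rolled check mirroring the fragment above (no validator library)."""
    return isinstance(value, str) and len(value) >= ARTIFACT_TYPE_SCHEMA["minLength"]


print(is_valid_artifact_type("torch.save"))        # a listed example
print(is_valid_artifact_type("my-custom-format"))  # not listed, still allowed
print(is_valid_artifact_type(""))                  # empty string is rejected
```

Keeping the check this loose means new frameworks can be described without a schema release, at the cost of not catching typos in well-known values.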
Although torch.compile is not recommended, I wouldn't remove it completely, since some existing artifacts might exist and still depend on it (e.g.: older code that doesn't work with newer methods).
I'm not saying that torch.compile is deprecated or anything. I'm saying it never was, and currently is not, used for saving model artifacts, so it shouldn't be listed as an artifact type. It is a very new tool built for a different purpose: optimizing eager code, not saving out full model graphs (model artifacts).
I am confident that nobody is producing model artifacts with torch.compile, since it is by design unable to do this. The torch docs state this. torch.export and torch.compile share optimization code paths (Torch Inductor), but only torch.export produces model artifacts.
I disagree.
torch.compile can use the same .pt2 as torch.export.
The only issue (which torch.export is trying to address) is that torch.compile can break when a certain reference cannot be resolved and it tries to fall back to the Python implementation, which might not exist if the package defining it is missing.
Because torch.export is still a prototype feature in development, we should still give the option to use torch.compile. Furthermore, there is still official PyTorch documentation using this exact strategy: https://github.com/pytorch/serve/blob/master/examples/pt2/README.md#torchcompile
Here's the example you linked from torchserve, which is discussing packaging both model code (that is optimized with torch.compile) and a state_dict (.pt), which is the model artifact file.
import torch
# `mod` is any torch.nn.Module defined elsewhere
# 1. Convert a regular module to an optimized module
opt_mod = torch.compile(mod)
# 2. Train the optimized module
# ....
# 3. Save the state dict (opt_mod shares its parameters with mod)
torch.save(mod.state_dict(), "model.pt")
# 4. Reload the weights into the module
mod.load_state_dict(torch.load("model.pt"))
# 5. Compile the module and then run inferences with it
opt_mod = torch.compile(mod)
from that torchserve doc:
torchserve takes care of 4 and 5 for you while the remaining steps are your responsibility. You can do the exact same thing on the vast majority of TIMM or HuggingFace models.
In this example, torch.compile needs to be called twice, at train time and at inference time. This is because the artifact that is saved (model.pt) does not bake in the optimizations from torch.compile.
Therefore, torch.compile is not used to produce model artifacts, and the model can be deployed without torch.compile. Using torch.compile is a detail about the runtime dependencies for models that depend on eager PyTorch code.
torch.compile can break when a certain reference cannot be resolved and trying to fall back to the Python implementation
This is true, but it only happens in eager mode, because there's no way to bake these torch.compile optimizations into a saved model artifact. To do that, you need torch.export.
To summarize the above, torch.compile optimizes a model's graph, but doesn't handle saving the model graph + state_dict weights. I think it'd be great to have a place in the MLM extension for calling out what hardware-specific optimizations have been applied to an inference pipeline that depends on nn.Module or similar ML framework constructs. But I don't think the model artifact section is the right place to do that.
My intuition would be that mlm:accelerator, mlm:accelerator_constrained, etc. would be used in combination with the corresponding state dict generated for a given hardware optimization. Beyond that, I don't think it is the role of MLM anymore to be explicit about each parameter, and something like torch.export that addresses those issues should be used instead if that really matters.
The idea of specifying mlm:artifact_type = torch.compile instead of directly using mlm:artifact_type = torch.save is simply to hint, as you mentioned and as in the example code, that this step was performed before torch.save to obtain the .pt2, and that reloading the model should ideally do torch.load + torch.compile rather than only torch.load. Since the compile step could modify which inference operations actually happen and which results are obtained, for example by affecting numerical precision, replicable results need to consider this hint that affects the obtained state dict. In other words, mlm:artifact_type doesn't only indicate the "content-type" of the artifact, but how it is intended to be used as well.
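To make the "hint" reading concrete, here is a minimal sketch of how a consumer might translate mlm:artifact_type into the calls it is expected to make when loading. The lookup table and the load_steps helper are hypothetical; the sketch only returns the names of the calls and does not import PyTorch.

```python
# Hypothetical mapping from mlm:artifact_type to the ordered calls a consumer
# would make before inference. Names are returned as strings; nothing is run.
LOAD_STEPS = {
    "torch.save": ["torch.load"],
    "torch.compile": ["torch.load", "torch.compile"],
}


def load_steps(asset):
    """Return the expected call sequence for a STAC asset dict."""
    artifact_type = asset.get("mlm:artifact_type")
    if artifact_type not in LOAD_STEPS:
        raise ValueError(f"no loading recipe known for {artifact_type!r}")
    return LOAD_STEPS[artifact_type]


asset = {"href": "model.pt", "mlm:artifact_type": "torch.compile"}
print(load_steps(asset))  # ['torch.load', 'torch.compile']
```

Under this reading, two assets with the same file format can still map to different loading recipes.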
Now, I'm not saying torch.compile is necessarily the "right way", but it is "a way", and I prefer MLM indicating it to guide best practices (and warn against bad-practices for that matter).
torch.compile doesn't affect the state_dict at all though. It has no bearing on the contents of the file from torch.save, so I don't think we should guide users to call torch.compile an artifact type.
Also, it is not the convention to save state_dict weights as ".pt2". Of course it is possible, but the torch docs state it is convention to only use .pt2 for torch.export. I could call a torch.export file .pt, just like I could call a .tiff file a .txt, but that'd be misleading and go against convention.
The same docs also say:
torch.compile() also utilizes the same PT2 stack as torch.export
and
This feature is a prototype under active development and there WILL BE BREAKING CHANGES in the future.
Therefore, the convention is not really well-established. I cannot in good conscience force users toward torch.export until the feature is stable, and the workaround is currently to use torch.compile, which has been available for a long time.
As mentioned previously,
mlm:artifact_type doesn't only indicate the "content-type" of the artifact, but how it is intended to be used as well
Therefore, even if the state_dict would be identical with/without torch.compile, this is a hint telling you explicitly that the reproducible runtime using the provided .pt[h][2] file expects the user to do torch.compile before running model inference, and if not done, results might differ (either in terms of prediction values, runtime speed, etc.). If we could only rely on the file type to infer how to employ the model, we wouldn't need the mlm:artifact_type. This is exactly what https://github.com/stac-extensions/mlm#artifact-type-enum is already mentioning.
Again, I have no issue with adding more references to explanations/recommendations for using a preferred approach over another in the best-practices for better reproducibility, but I do not think this is a sufficient reason to remove torch.compile entirely. There are some scripts out there that still use torch.compile, and they must be able to indicate it somehow. I would be happy to link to any official PyTorch guidance in the best-practices document about which approach is better and why to use it, but this explanation should most definitely live in the PyTorch docs, not in MLM, since the extension intends to be framework-agnostic.
mlm:artifact_type doesn't only indicate the "content-type" of the artifact, but how it is intended to be used as well
Maybe this is the core issue. Are we overloading artifact_type with these two concepts: "content-type" and "intended use"? I think so. And I think it will confuse users looking at MLM metadata.
Therefore, even if the state_dict would be identical with/without torch.compile, this is a hint telling you explicitly that the reproducible runtime using the provided .pt[h][2] file expects the user to do torch.compile before running model inference, and if not done, results might differ (either in terms of prediction values, runtime speed, etc.).
This seems like a lot to infer. I'd rather have a separate field that marks what optimizations are made in the inference pipeline that don't affect the content type. Then it would be clear that the content-type field indicates how to load the model, and the other, optimization-focused field dictates the suggested code path to use at inference time.
I also have no issue with adding more recommendations and I don't want to remove mention of torch.compile. I'm fine even with removing torch.export until the API is solidified and stable. I also don't think torch.compile is out of date or deprecated, or that it's an either/or between torch.compile and torch.export; they serve different purposes. I just don't want users to think that we are hinting that it defines a unique content type.
My motivation for being more explicit here wrt content-type vs. properties of the inference pipeline is that earlier this year I was very confused going through the PyTorch documentation on how to produce a compiled model artifact that had no runtime dependencies other than PyTorch. I think it is unfortunate that the docs do not make it clear that torch.compile and torch.export are not used for the same purposes even though they share the underlying compile workflow.
I spent a good amount of time thinking torch.compile would produce the model artifact I needed when it wasn't the right tool. I don't want to mislead users into doing the same.
overloading artifact_type with these two concepts: "content-type" and "intended use"
Somewhat (and maybe even more than 2 concepts).
The "content-type" should really be indicated by the type, as detailed in https://github.com/stac-extensions/mlm#model-artifact-media-type. However, because .pt (and its corresponding content media-type), could refer to different (and overlapping) encoding/frameworks/etc., the type alone is not enough. This is where the mlm:artifact_type comes in to add more context on top to disambiguate similar uses of the same type. The type and mlm:artifact_type should somewhat align naturally.
Note that https://github.com/stac-extensions/mlm#artifact-type-enum also explicitly indicates that the names like torch.compile/torch.export are used explicitly to give a very obvious hint about what should be done with the artifact file. However, this is an arbitrary value. Using mlm:artifact_type: "created with torch.save and should be loaded using torch.load followed by torch.compile" would be technically just as valid, but much less appealing and convenient to use.
The reason a single mlm:artifact_type property is used is that, in pretty much any other case, we would inject framework-dependent concepts. For example, if we defined some kind of mlm:artifact_compile: true|false property to indicate whether torch.compile is used or not, what happens for frameworks that have no concept of "compile"? What if they do, but that refers to something else entirely? Then, MLM definitions are not meaningful anymore in terms of searching/filtering STAC items. I want to avoid having any sort of listing of specific properties from all possible frameworks, because it becomes a maintenance burden and goes against the unifying intention of MLM.
If usage is not obvious from the type and mlm:artifact_type alone, then documentation should be provided, either as a description property, or straight-up as an accompanying reference code/script using roles: [mlm:source_code] that makes usage obvious.
Personally, I think this is such a specific edge-case, that we would be better off by just adding a "⚠️ (see best-practice)" next to torch.compile in the table, and provide some explanations of the intended use. Then, users are free to follow that recommendation or not.
Personally, I think this is such a specific edge-case
PyTorch, TensorFlow, and JAX all provide mechanisms for JIT and AOT compiled models. JIT and AOT models have very different deployment environments and require different levels of effort from those looking to run the models. I think it is important that the MLM capture this variation. It's probably something users even want to search on, because "level of effort to try the model out" is often a first-order concern.
Examples of AOT and JIT options besides PyTorch: https://www.tensorflow.org/tutorials/keras/save_and_load#savedmodel_format https://jax.readthedocs.io/en/latest/export/export.html
I think if we try to describe the artifact_types for JAX and TensorFlow models this would come up again. Therefore I'd like to be more explicit that the artifact_type is the method of how a model was saved. We're already doing this to some extent, with the exceptions being:
| Old Artifact Type | New Artifact Type |
|---|---|
| torch.jit.script | torch.jit.save |
| torch.export | torch.export.save |
| torch.compile | remove in favor of a separate, framework agnostic jit_compiled flag on the model source code asset |
I think this gets around the problem of a .pt file not being very informative. It would also indicate whether a model was AOT compiled or not. Users who find it useful can look up the framework-specific methods listed here and learn more about them when choosing how to save and describe their models.
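The proposed renaming could be sketched as a small migration helper. The function name migrate_asset is hypothetical, and mapping torch.compile to torch.save is an assumption based on the earlier point that the saved file is a plain state_dict. For brevity the jit_compiled flag is set on the same asset here, whereas the proposal places it on the model source code asset.

```python
# Renames proposed in the table above. The torch.compile row has no direct
# replacement, so it is handled separately below.
RENAMES = {
    "torch.jit.script": "torch.jit.save",
    "torch.export": "torch.export.save",
}


def migrate_asset(asset):
    """Return a copy of a STAC asset dict using the proposed artifact types."""
    migrated = dict(asset)
    old_type = asset.get("mlm:artifact_type")
    if old_type == "torch.compile":
        # Assumption: the file itself was written with torch.save, and the
        # compile step is recorded in a separate (hypothetical) flag.
        migrated["mlm:artifact_type"] = "torch.save"
        migrated["jit_compiled"] = True
    elif old_type in RENAMES:
        migrated["mlm:artifact_type"] = RENAMES[old_type]
    return migrated


print(migrate_asset({"href": "model.pt", "mlm:artifact_type": "torch.compile"}))
```

Values that are not in the table pass through unchanged, which fits the idea of keeping the field a lax string.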
we would be better off by just adding a "⚠️ (see best-practice)" next to torch.compile in the table, and provide some explanations of the intended use. Then, users are free to follow that recommendation or not.
I think the MLM README is already too long and difficult to navigate, and I've heard this reaction from some folks who have gone through it at Wherobots. IMO the spec is too open to interpretation and we are trying to fill in those gaps with recommendations. Rather than you and I writing more paragraphs of recommendations to plug gaps, I think we could make the spec more helpful by providing the options to accurately describe a model without ambiguity.
I'm down to do the work here with respect to describing the artifact_type options, and how to represent whether the model source code asset has a JIT compile step. This could be an additional boolean field that indicates the presence of JIT compilation somewhere.
The "content-type" should really be indicated by the type, as detailed in https://github.com/stac-extensions/mlm#model-artifact-media-type. However, because .pt (and its corresponding content media-type), could refer to different (and overlapping) encoding/frameworks/etc., the type alone is not enough.
Ideally I would want there to be clear IANA types, but since there are not, I don't see authors of MLM metadata using this if there are no conventional options. And as you noted, it wouldn't be informative enough on its own to understand how to load the model because of the .pt ambiguity.
belatedly coming back to your earlier comment
The examples don't have to be an exhaustive list. The mlm:artifact_type field is not even explicitly defined in the JSON schema at the moment. We should define it as a lax string, similar to other definitions:
agree, I can do this
I think the MLM README is already too long and difficult to navigate
I agree. This is why I proposed adding these details in the best-practices. The full definition and table in the current https://github.com/stac-extensions/mlm#artifact-type-enum could actually live entirely in a specific best-practices section. From the point of view of the https://github.com/stac-extensions/mlm#model-asset, the mlm:artifact_type is just "some string". Framework function names are used only to facilitate reuse/comprehension; they are not required.
Framework-specific details shouldn't be explicit at the README level.
framework agnostic jit_compiled
I think we need to refine this definition. Rather than adding mlm:jit_compiled and starting to accumulate flags for every combination, I think it is better to have a mlm:compiled_method with a set of values like "aot", "jit", null (default/none applied).
There is also a mlm:compiled role under https://github.com/stac-extensions/mlm#model-asset that should be better described (or removed, since it is redundant with mlm:compiled_method?) in the best-practices.
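A minimal sketch of how such a field could be read and validated, assuming the three values suggested above and treating a missing field the same as null. The helper name and the choice to default to None are illustrative only.

```python
# Allowed values for the proposed mlm:compiled_method field. None stands for
# JSON null (default / no compilation applied).
COMPILED_METHODS = ("aot", "jit", None)


def compiled_method(asset):
    """Read mlm:compiled_method from an asset dict, defaulting to None."""
    method = asset.get("mlm:compiled_method")
    if method not in COMPILED_METHODS:
        raise ValueError(f"unsupported mlm:compiled_method: {method!r}")
    return method


print(compiled_method({"mlm:compiled_method": "jit"}))  # jit
print(compiled_method({"href": "model.pt"}))            # None
```

A single enumerated field like this stays searchable across frameworks, since "aot" and "jit" are not tied to any one library's function names.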
:rocket: Feature Request
Currently we only suggest Pytorch artifact types.
And torch.compile is not an artifact type. From the pytorch docs https://pytorch.org/docs/stable/export.html#existing-frameworks
So I think we should remove torch.compile from the list since no one should be using this to specify a model artifact type.
Here's an initial list that includes more frameworks
- Scikit-learn (Python)
- TensorFlow (Python)
- ONNX (Open Neural Network Exchange) (Language-agnostic)
  - .onnx
- PyTorch (Python)
- Other Frameworks
  - XGBoost (framework specific binary format)
  - LightGBM (framework specific binary format)
  - PMML (Predictive Model Markup Language - XML)
- R
  - .rds, .rda ???
- Julia
  - JLD, JLD2 BSON ???
The above list was partially LLM-generated, so take it with some salt. I can look into and confirm use if we decide to move forward with this, and provide a more exhaustive set of options.
:sound: Motivation
We'd like users outside of the Pytorch ecosystem to understand how to describe their model artifacts so that it is easier to know what part of a framework should be used to load the model and the different runtime dependencies involved since different artifact types have different runtime dependencies.
Francis raised that the list is currently overly specific to Pytorch and I agree: https://github.com/crim-ca/dlm-extension/pull/2#discussion_r1560193533
:satellite: Alternatives
Keep this part of the spec focused on Pytorch and loose with no more recommendations.
:paperclip: Additional context
I'm down to expand this suggestion. I think in terms of validation, we can be lax on the artifact value type that is passed here.