stac-extensions / mlm

STAC Machine Learning Model (MLM) Extension to describe ML models, their training details, and inference runtime requirements.
https://stac-extensions.github.io/mlm/
Apache License 2.0

Revise and expand artifact types to include more frameworks, remove torch.compile #31

Open rbavery opened 5 days ago

rbavery commented 5 days ago

:rocket: Feature Request

Currently we only suggest PyTorch artifact types.

And torch.compile is not an artifact type. From the PyTorch docs:

torch.compile() is a JIT compiler whereas which is not intended to be used to produce compiled artifacts outside of deployment.

https://pytorch.org/docs/stable/export.html#existing-frameworks

So I think we should remove torch.compile from the list since no one should be using this to specify a model artifact type.

Here's an initial list that includes more frameworks:

- Scikit-learn (Python)
- TensorFlow (Python)
- ONNX (Open Neural Network Exchange) (language-agnostic)
  - .onnx
- PyTorch (Python)
- Other frameworks
  - XGBoost (framework-specific binary format)
  - LightGBM (framework-specific binary format)
  - PMML (Predictive Model Markup Language, XML)
- R
  - .rds, .rda ???
- Julia
  - JLD, JLD2, BSON ???

The above list was partially LLM-generated, so take it with a grain of salt. If we decide to move forward with this, I can look into it, confirm usage, and provide a more exhaustive set of options.

:sound: Motivation

We'd like users outside of the PyTorch ecosystem to understand how to describe their model artifacts. That makes it easier to know which part of a framework should be used to load the model and which runtime dependencies are involved, since different artifact types have different runtime dependencies.

Francis raised that the list is currently overly specific to PyTorch, and I agree: https://github.com/crim-ca/dlm-extension/pull/2#discussion_r1560193533

:satellite: Alternatives

Keep this part of the spec focused on PyTorch and loose, with no further recommendations.

:paperclip: Additional context

I'm down to expand this suggestion. In terms of validation, I think we can be lax about the artifact type value passed here.

fmigneault commented 4 days ago

Although torch.compile is not recommended, I wouldn't remove it completely, since some existing artifacts might exist and still depend on it (e.g.: older code that doesn't work with newer methods). Instead, I think we should add a (*) + "WARNING" block under https://github.com/stac-extensions/mlm?tab=readme-ov-file#artifact-type-enum to "STRONGLY" advise migrating to other, more appropriate artifact types. Removing it might cause more problems through obscurity than explicitly stating why it is bad practice would.

As for the other proposed types, I agree. We need more non-pytorch examples.

The examples don't have to be an exhaustive list. The mlm:artifact_type field is not even explicitly defined in the JSON schema at the moment. We should define it as a lax string, similar to other definitions: https://github.com/stac-extensions/mlm/blob/25bef80362629188c78938277643a2c53659163b/json-schema/schema.json#L325-L366
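As a rough illustration of "a lax string", here is a minimal sketch of what such a definition might look like, written as a Python dict that mirrors JSON Schema keywords. The exact shape (a non-empty string with non-binding `examples`) and the example values are assumptions, not the schema currently published in the repository, and `validate_artifact_type` is a hypothetical helper standing in for a real JSON Schema validator.

```python
import json

# Hypothetical, lax definition of mlm:artifact_type: any non-empty string is accepted.
# The "examples" are illustrative hints only, not an enum.
ARTIFACT_TYPE_SCHEMA = {
    "$comment": "Sketch only; not the published MLM schema.",
    "type": "string",
    "minLength": 1,
    "examples": ["torch.save", "torch.jit.script", "torch.export"],
}


def validate_artifact_type(value):
    """Hypothetical stand-in for a JSON Schema validator applying the rules above."""
    if not isinstance(value, str):
        return False
    return len(value) >= ARTIFACT_TYPE_SCHEMA["minLength"]


if __name__ == "__main__":
    print(json.dumps(ARTIFACT_TYPE_SCHEMA, indent=2))
    for candidate in ["torch.save", "xgboost.Booster.save_model", "", 42]:
        print(repr(candidate), validate_artifact_type(candidate))
```

Because the check only enforces "non-empty string", framework-specific values outside the examples (such as the XGBoost one above) pass without the schema having to list them.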

rbavery commented 4 days ago

Although torch.compile is not recommended, I wouldn't remove it completely, since some existing artifacts might exist and still depend on it (e.g.: older code that doesn't work with newer methods).

I'm not saying that torch.compile is deprecated or anything. I'm saying it never was, and currently is not, used for saving model artifacts, so it shouldn't be listed as an artifact type. It is a new tool built for a different purpose: optimizing eager code, not saving out full model graphs (model artifacts).

I am confident that nobody is producing model artifacts with torch.compile, since it is by design unable to do this. The torch docs state this. torch.export and torch.compile share optimization code paths (TorchInductor), but only torch.export produces model artifacts.

fmigneault commented 3 days ago

I disagree.

torch.compile can use the same .pt2 as torch.export. The only issue (which torch.export is trying to address) is that torch.compile can break when a certain reference cannot be resolved and trying to fall back to the Python implementation, which might not exist if the package defining it is missing.

Because torch.export is still a prototype feature under development, we should still give the option to use torch.compile. Furthermore, there is still official PyTorch documentation using this exact strategy: https://github.com/pytorch/serve/blob/master/examples/pt2/README.md#torchcompile

rbavery commented 3 days ago

Here's the example you linked from torchserve, which discusses packaging both model code (optimized with torch.compile) and a state_dict (.pt), which is the model artifact file.

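import torch

# `mod` stands for any torch.nn.Module; a tiny Linear layer is used so the snippet runs.
mod = torch.nn.Linear(4, 2)

# 1. Convert a regular module to an optimized module
opt_mod = torch.compile(mod)
# 2. Train the optimized module
# ....
# 3. Save the state dict. opt_mod shares its parameters with mod, so mod.state_dict()
#    holds the trained weights without the "_orig_mod." key prefix that
#    opt_mod.state_dict() would add. No compiled code is stored in model.pt.
torch.save(mod.state_dict(), "model.pt")

# 4. Reload the weights into an uncompiled instance of the same architecture
mod = torch.nn.Linear(4, 2)
mod.load_state_dict(torch.load("model.pt"))

# 5. Compile the module and then run inferences with it
opt_mod = torch.compile(mod)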

from that torchserve doc:

torchserve takes care of 4 and 5 for you while the remaining steps are your responsibility. You can do the exact same thing on the vast majority of TIMM or HuggingFace models.

In this example, torch.compile needs to be called twice: at train time and at inference time. This is because the artifact (model.pt) that is saved does not bake in the optimizations from torch.compile.

Therefore, torch.compile is not used to produce model artifacts, and the model can be deployed without torch.compile. torch.compile is a detail of the runtime dependencies for models that depend on eager PyTorch code.

torch.compile can break when a certain reference cannot be resolved and trying to fall back to the Python implementation

This is true, but it only happens in eager mode, because there is no way to bake these torch.compile optimizations into a saved model artifact. To do that, you need torch.export.

rbavery commented 3 days ago

To summarize the above: torch.compile optimizes a model's graph, but doesn't handle saving the model graph + state_dict weights. I think it'd be great to have a place in the MLM extension for calling out which hardware-specific optimizations have been applied to an inference pipeline that depends on nn.Module or similar ML framework constructs. But I don't think the model artifact section is the right place to do that.

fmigneault commented 3 days ago

My intuition would be that mlm:accelerator, mlm:accelerator_constrained, etc. would be used in combination with the corresponding state dict generated for a given hardware optimization. Beyond that, I don't think it is the role of MLM anymore to be explicit about each parameter, and something like torch.export that addresses those issues should be used instead if that really matters.

The idea of specifying mlm:artifact_type = torch.compile instead of directly using mlm:artifact_type = torch.save is simply to hint, as you mentioned and as in the example code, that this step was performed before torch.save to obtain the .pt2, and that reloading the model should ideally do torch.load + torch.compile rather than only torch.load. Since the compile step could modify which inference operations actually happen and which results are obtained (for example, by affecting numerical precision), replicable results need to consider this hint that affects the obtained state dict. In other words, mlm:artifact_type doesn't only indicate the "content-type" of the artifact, but how it is intended to be used as well.

Now, I'm not saying torch.compile is necessarily the "right way", but it is "a way", and I prefer MLM indicating it to guide best practices (and warn against bad-practices for that matter).

rbavery commented 3 days ago

torch.compile doesn't affect the state_dict at all, though. It has no bearing on the contents of the file produced by torch.save, so I don't think we should guide users to treat torch.compile as an artifact type.

Also, it is not the convention to save state_dict weights as ".pt2". Of course it is possible, but the torch docs state that the convention is to use .pt2 only for torch.export. I could call a torch.export file .pt, just like I could call a .tiff file a .txt, but that'd be misleading and go against convention.

fmigneault commented 3 days ago

The same docs also say:

torch.compile() also utilizes the same PT2 stack as torch.export

and

This feature is a prototype under active development and there WILL BE BREAKING CHANGES in the future.

Therefore, the convention is not really well-established. I cannot consciously force users toward torch.export until the feature is stable, and the workaround is currently to use torch.compile that has been available for a long time.

As mentioned previously,

mlm:artifact_type doesn't only indicate the "content-type" of the artifact, but how it is intended to be used as well

Therefore, even if the state_dict would be identical with/without torch.compile, this is a hint telling you explicitly that the reproducible runtime using the provided .pt[h][2] file expects the user to do torch.compile before running model inference, and if not done, results might differ (either in terms of prediction values, runtime speed, etc.). If we could only rely on the file type to infer how to employ the model, we wouldn't need the mlm:artifact_type. This is exactly what https://github.com/stac-extensions/mlm#artifact-type-enum is already mentioning.

Again, I have no issue about adding more references to explanations/recommendations for using a preferred approach over another in the best-practices for better reproducibility, but I do not think this is a sufficient reason to remove torch.compile entirely. There are some scripts out there that still use torch.compile, and they must be able to indicate it somehow. I would be happy to link to any official pytorch guidance in the best-practices document about which approach is better and why to use it, but this explanation should most definitely live in pytorch docs, not in MLM, since the extension intends to be framework-agnostic.

rbavery commented 3 days ago

mlm:artifact_type doesn't only indicate the "content-type" of the artifact, but how it is intended to be used as well

Maybe this is the core issue. Are we overloading artifact_type with these two concepts: "content-type" and "intended use"? I think so, and I think it will confuse users looking at MLM metadata.

Therefore, even if the state_dict would be identical with/without torch.compile, this is a hint telling you explicitly that the reproducible runtime using the provided .pt[h][2] file expects the user to do torch.compile before running model inference, and if not done, results might differ (either in terms of prediction values, runtime speed, etc.).

This seems like a lot to infer. I'd rather have a separate field that marks which inference optimizations are made in the inference pipeline without affecting the content type. Then it would be clear that the content-type field indicates how to load the model, while the other, optimization-focused field dictates the suggested code path to use at inference time.

I also have no issue with adding more recommendations, and I don't want to remove mention of torch.compile. I'm even fine with removing torch.export until the API is solidified and stable. I also don't think torch.compile is out of date or deprecated, or that it's an either/or between torch.compile and torch.export; they serve different purposes. I just don't want users to think that we are hinting that it defines a unique content type.

rbavery commented 3 days ago

My motivation for being more explicit here wrt content-type vs. properties of the inference pipeline is that earlier this year I was very confused going through the PyTorch documentation on how to produce a compiled model artifact that had no runtime dependencies other than PyTorch. I think it is unfortunate that the docs do not make it clear that torch.compile and torch.export are not used for the same purposes, even though they share the underlying compile workflow.

I spent a good amount of time thinking torch.compile would produce the model artifact I needed when it wasn't the right tool. I don't want to mislead users into doing the same.

fmigneault commented 3 days ago

overloading artifact_type with these two concepts: "content-type" and "intended use"

Somewhat (and maybe even more than 2 concepts).

The "content-type" should really be indicated by the type, as detailed in https://github.com/stac-extensions/mlm#model-artifact-media-type. However, because .pt (and its corresponding content media-type), could refer to different (and overlapping) encoding/frameworks/etc., the type alone is not enough. This is where the mlm:artifact_type comes in to add more context on top to disambiguate similar uses of the same type. The type and mlm:artifact_type should somewhat align naturally.

Note that https://github.com/stac-extensions/mlm#artifact-type-enum also explicitly indicates that the names like torch.compile/torch.export are used explicitly to give a very obvious hint about what should be done with the artifact file. However, this is an arbitrary value. Using mlm:artifact_type: "created with torch.save and should be loaded using torch.load followed by torch.compile" would be technically just as valid, but much less appealing and convenient to use.
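To make this "content-type plus intended use" reading concrete, here is a hypothetical lookup from `mlm:artifact_type` strings to the loading steps a consumer might infer from them. The step lists paraphrase this thread (e.g. `torch.compile` read as "load with `torch.load`, then compile"); they are descriptive strings, not a registry defined by MLM, and nothing in the sketch calls PyTorch.

```python
# Hypothetical mapping from mlm:artifact_type values to the loading steps a consumer
# would infer from them. Descriptive strings only; nothing here calls PyTorch.
LOADING_HINTS = {
    "torch.save": ["torch.load", "model.load_state_dict"],
    "torch.compile": ["torch.load", "model.load_state_dict", "torch.compile"],
    "torch.jit.script": ["torch.jit.load"],
    "torch.export": ["torch.export.load"],
}


def inferred_steps(artifact_type):
    """Return the inferred loading steps, or a fallback for unrecognized values."""
    # mlm:artifact_type is free-form, so unknown values (like a long sentence) are expected.
    return LOADING_HINTS.get(
        artifact_type,
        ["read the asset description or the mlm:source_code asset"],
    )


print(inferred_steps("torch.compile"))  # ['torch.load', 'model.load_state_dict', 'torch.compile']
print(inferred_steps("created with torch.save and should be loaded using torch.load"))
```

The short, function-like names are what make the lookup convenient; the long descriptive value is equally valid but can only hit the fallback.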

The reason a single mlm:artifact_type property is used is that, in pretty much any other case, we would inject framework-dependent concepts. For example, if we defined some kind of mlm:artifact_compile: true|false property to indicate whether torch.compile is used or not, what happens for frameworks that have no concept of "compile"? What if they do, but it refers to something else entirely? Then MLM definitions are no longer meaningful for searching/filtering STAC items. I want to avoid having any sort of listing of specific properties from all possible frameworks, because it becomes a maintenance burden and goes against the unifying intention of MLM.

If usage is not obvious from the type and mlm:artifact_type alone, then documentation should be provided, either as a description property, or straight up as accompanying reference code/script using roles: [mlm:source_code] that makes usage obvious.
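A sketch of the item `assets` this suggests, written as a Python dict: the model asset carries the `mlm:artifact_type` hint plus a free-text `description`, and a second asset with `roles: ["mlm:source_code"]` points at the reference script. The `mlm:model` role, the example.com hrefs, and the media types are placeholders assumed for illustration; the MLM README is the place to check the exact role and media-type conventions.

```python
import json

# Illustrative STAC item "assets" block; hrefs and media types are placeholders.
assets = {
    "weights": {
        "href": "https://example.com/model.pt",
        "type": "application/octet-stream",
        "roles": ["mlm:model"],
        "mlm:artifact_type": "torch.compile",
        "description": (
            "State dict saved with torch.save; load with torch.load, "
            "then call torch.compile before running inference."
        ),
    },
    "inference-script": {
        "href": "https://example.com/inference.py",
        "type": "text/x-python",
        "roles": ["mlm:source_code"],
        "description": "Reference code showing how the weights are loaded and compiled.",
    },
}

# Find the asset(s) that document usage.
source_code = [key for key, asset in assets.items() if "mlm:source_code" in asset["roles"]]
print(source_code)  # ['inference-script']
print(json.dumps(assets["weights"], indent=2))
```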

Personally, I think this is such a specific edge-case, that we would be better off by just adding a "⚠️ (see best-practice)" next to torch.compile in the table, and provide some explanations of the intended use. Then, users are free to follow that recommendation or not.

rbavery commented 3 days ago

Personally, I think this is such a specific edge-case

PyTorch, TensorFlow, and JAX all provide mechanisms for JIT and AOT compiled models. JIT and AOT models have very different deployment environments and levels of effort for those looking to run the models. I think it is important that the MLM capture this variation; it's probably something users even want to search on, because "level of effort to try the model out" is often a first-order concern.

Examples of AOT and JIT options besides PyTorch:
https://www.tensorflow.org/tutorials/keras/save_and_load#savedmodel_format
https://jax.readthedocs.io/en/latest/export/export.html

I think if we try to describe the artifact_types for JAX and TensorFlow models, this would come up again. Therefore I'd like to be more explicit that the artifact_type is the method by which a model was saved. We're already doing this to some extent, with the following exceptions:

Old artifact type | New artifact type
torch.jit.script  | torch.jit.save
torch.export      | torch.export.save
torch.compile     | removed, in favor of a separate, framework-agnostic jit_compiled flag on the model source code asset

I think this gets around the problem of a .pt file not being very informative. It would also convey whether or not a model was AOT compiled. Users who find it useful can look up the framework-specific methods listed here and learn more about them when choosing how to save and describe their models.
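Applied to existing metadata, the proposed renames could be sketched roughly as follows. `migrate_assets` is a hypothetical helper, and `jit_compiled` is only the flag name proposed in this comment (its final name and placement are still open). The sketch also assumes, as described earlier in the thread, that a `torch.compile` artifact was actually written with `torch.save`.

```python
import copy

# Proposed renames from the table above (a proposal, not the published spec).
RENAMES = {
    "torch.jit.script": "torch.jit.save",
    "torch.export": "torch.export.save",
}


def migrate_assets(assets):
    """Return a migrated copy of a STAC item's assets dict (hypothetical helper)."""
    migrated = copy.deepcopy(assets)
    had_compile_hint = False
    for asset in migrated.values():
        artifact_type = asset.get("mlm:artifact_type")
        if artifact_type in RENAMES:
            asset["mlm:artifact_type"] = RENAMES[artifact_type]
        elif artifact_type == "torch.compile":
            # Assumed: the weights were saved with torch.save; JIT is a runtime detail.
            asset["mlm:artifact_type"] = "torch.save"
            had_compile_hint = True
    if had_compile_hint:
        for asset in migrated.values():
            if "mlm:source_code" in asset.get("roles", []):
                asset["jit_compiled"] = True  # hypothetical flag name from the proposal
    return migrated


before = {
    "weights": {"roles": ["mlm:model"], "mlm:artifact_type": "torch.compile"},
    "code": {"roles": ["mlm:source_code"]},
}
after = migrate_assets(before)
print(after["weights"]["mlm:artifact_type"], after["code"]["jit_compiled"])  # torch.save True
```

Working on a deep copy keeps the original metadata untouched, so old and migrated records can be compared side by side.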

we would be better off by just adding a "⚠️ (see best-practice)" next to torch.compile in the table, and provide some explanations of the intended use. Then, users are free to follow that recommendation or not.

I think the MLM README is already too lengthy and difficult to navigate, and I've heard the same reaction from some folks at Wherobots who have gone through it. IMO the spec is too open to interpretation, and we are trying to fill in those gaps with recommendations. Rather than you and me writing more paragraphs of recommendations to plug gaps, I think we could make the spec more helpful by providing the options to describe a model accurately and without ambiguity.

I'm down to do the work here wrt describing the artifact_type options, and how to represent whether the model source code asset has a JIT compile step. This could be an additional boolean field that indicates the presence of JIT compilation somewhere.

The "content-type" should really be indicated by the type, as detailed in https://github.com/stac-extensions/mlm#model-artifact-media-type. However, because .pt (and its corresponding content media-type), could refer to different (and overlapping) encoding/frameworks/etc., the type alone is not enough.

Ideally I would want there to be clear IANA media types, but since there are none, I don't see authors of MLM metadata using this field without conventional options. And as you noted, it wouldn't be informative enough on its own to understand how to load the model, because of the .pt ambiguity.
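The `.pt` ambiguity can be illustrated with a small lookup: one extension maps to several plausible artifact types, and only an explicit `mlm:artifact_type` narrows it down. The mapping below is a hypothetical, partial sketch based on the conventions discussed in this thread (`.pt2` for `torch.export.save`, the proposed `*.save` names, and an assumed `onnx` value), not an authoritative registry.

```python
from pathlib import PurePosixPath

# Partial, illustrative mapping from file extensions to artifact types that may use them.
CANDIDATES = {
    ".pt": {"torch.save", "torch.jit.save"},
    ".pth": {"torch.save"},
    ".pt2": {"torch.export.save"},
    ".onnx": {"onnx"},  # assumed artifact type string
}


def candidate_types(href):
    """Artifact types plausible for an href's extension (hrefs without query strings)."""
    return sorted(CANDIDATES.get(PurePosixPath(href).suffix.lower(), set()))


def resolve(href, artifact_type=None):
    """Prefer the explicit mlm:artifact_type; otherwise accept only an unambiguous extension."""
    if artifact_type is not None:
        return artifact_type
    candidates = candidate_types(href)
    if len(candidates) == 1:
        return candidates[0]
    raise ValueError(f"cannot infer artifact type from {href!r} (candidates: {candidates}); set mlm:artifact_type")


print(candidate_types("https://example.com/model.pt"))  # ['torch.jit.save', 'torch.save']
print(resolve("model.pt2"))  # torch.export.save
print(resolve("model.pt", "torch.jit.save"))  # torch.jit.save
```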

rbavery commented 3 days ago

belatedly coming back to your earlier comment

The examples don't have to be an exhaustive list. The mlm:artifact_type field is not even explicitly defined in the JSON schema at the moment. We should define it as a lax string, similar to other definitions:

agree, I can do this

fmigneault commented 1 day ago

I think the MLM README is already too lengthy and difficult to navigate

I agree. This is why I proposed adding these details to the best-practices. The full definition and table in the current https://github.com/stac-extensions/mlm#artifact-type-enum could actually live entirely in a specific best-practices section. From the point of view of https://github.com/stac-extensions/mlm#model-asset, mlm:artifact_type is just "some string". Framework function names are used only to facilitate reuse/comprehension; they are not required. Framework-specific details shouldn't be explicit at the README level.

framework agnostic jit_compiled

I think we need to refine this definition. Rather than adding mlm:jit_compiled and ending up with many flags for every combination, I think it is better to have an mlm:compiled_method with a set of values such as "aot", "jit", or null (default/none applied).

There is also a mlm:compiled role under https://github.com/stac-extensions/mlm#model-asset that should be better described (or removed since redundant from mlm:compiled_method?) in the best-practices.
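A minimal sketch of the proposed mlm:compiled_method, assuming the value set "aot", "jit", or null described above (`None` in Python, also used when the property is absent) and assuming it lives in item properties. `check_compiled_method` and `ready_to_run` are hypothetical helpers; the second shows the kind of "level of effort" search raised earlier, keeping only ahead-of-time compiled models.

```python
# Proposed values for mlm:compiled_method; None stands for JSON null or "not set".
# A tuple (not a set) so that membership tests also work on unhashable bad values.
ALLOWED_COMPILED_METHODS = ("aot", "jit", None)


def check_compiled_method(properties):
    """Return the compiled method, raising if it is not one of the proposed values."""
    value = properties.get("mlm:compiled_method")
    if value not in ALLOWED_COMPILED_METHODS:
        raise ValueError(f"unexpected mlm:compiled_method: {value!r}")
    return value


def ready_to_run(items):
    """Keep ids of items whose model was compiled ahead of time (the lower-effort case)."""
    return [item["id"] for item in items if check_compiled_method(item["properties"]) == "aot"]


items = [
    {"id": "exported", "properties": {"mlm:compiled_method": "aot"}},
    {"id": "eager", "properties": {"mlm:compiled_method": "jit"}},
    {"id": "plain", "properties": {}},
]
print(ready_to_run(items))  # ['exported']
```

A single enumerated property keeps the filter framework-agnostic: the same query works whether the AOT artifact came from PyTorch, TensorFlow, or JAX.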