pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

[RFC] Kaggle Model Hub Integration for Torchtune #1852

Open · KeijiBranshi opened this issue 3 days ago

KeijiBranshi commented 3 days ago

Authors: @keijibranshi @rosbo @mrisdal @neshdev @bflynn

Summary

This RFC proposes extending torchtune to support loading pre-trained and fine-tuned model weights directly from the Kaggle Model Hub. This integration aims to expand the accessibility of models within torchtune and contribute to the adoption of both PyTorch/torchtune and the Kaggle Model Hub by streamlining the experience for Kaggle users.

Motivation

This proposal aligns with PyTorch's objective of integrating with partner platforms to increase Torchtune adoption (KR 3.2). By adding support for Kaggle, we can:

Other Potential Benefits

Prior Art

Similar functionality for loading models from various hubs exists in other deep learning libraries. Keras, for example, provides a unified mechanism for loading models from both Hugging Face and Kaggle using URI schemes (see the Keras documentation).

Proposed Implementation

We propose extending torchtune's model loading mechanism to recognize and handle Kaggle Model Hub URIs. This will involve the following:

  1. URI Scheme Recognition: Torchtune can be updated to recognize model URIs using the kaggle:// scheme. While Hugging Face will remain the default source for models, we could also add support for explicit Hugging Face URIs using the hf:// scheme for increased clarity.

  2. Kaggle Hub Integration: Leverage the kagglehub Python library to handle the download and upload of model weights to and from the Kaggle Model Hub.

Using the above, we would modify torchtune's model loading logic (sketched below) to:

  1. Detect the URI scheme (kaggle:// or hf://).
  2. Utilize kagglehub for downloading weights from Kaggle when a kaggle:// URI is provided.
  3. Maintain the existing Hugging Face integration for models without a URI scheme or those explicitly using hf://.
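
For illustration, here is a minimal sketch of that dispatch, assuming the kagglehub and huggingface_hub packages are available; the helper name and exact wiring are illustrative, not a final torchtune API:

```python
import kagglehub
from huggingface_hub import snapshot_download

KAGGLE_PREFIX = "kaggle://"
HF_PREFIX = "hf://"

def download_weights(model_path: str) -> str:
    """Resolve a model path to a local directory containing the weights."""
    if model_path.startswith(KAGGLE_PREFIX):
        # kagglehub downloads into its cache and returns the local path.
        return kagglehub.model_download(model_path[len(KAGGLE_PREFIX):])
    if model_path.startswith(HF_PREFIX):
        model_path = model_path[len(HF_PREFIX):]
    # Default: treat the path as a Hugging Face repo id, as torchtune does today.
    return snapshot_download(repo_id=model_path)
```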

Example Usage:

Users will be able to download a model from Kaggle using a command like:

tune download kaggle://metaresearch/llama-3.2/pyTorch/3b \
--output-dir /tmp/llama-3.2-3b \
--kaggle-username <KAGGLE_USERNAME> \
--kaggle-api-key <KAGGLE_API_KEY>
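
For context on how the credential flags might be honored: kagglehub reads the KAGGLE_USERNAME and KAGGLE_KEY environment variables (or ~/.kaggle/kaggle.json), so one option is for the CLI to forward the flags into those variables before downloading. A sketch under that assumption, with an illustrative helper name:

```python
import os
from typing import Optional

import kagglehub

def download_from_kaggle(
    handle: str, username: Optional[str] = None, api_key: Optional[str] = None
) -> str:
    # Forward CLI-provided credentials to the environment variables kagglehub
    # checks when no other authentication (e.g. ~/.kaggle/kaggle.json) is set up.
    if username:
        os.environ["KAGGLE_USERNAME"] = username
    if api_key:
        os.environ["KAGGLE_KEY"] = api_key
    return kagglehub.model_download(handle)
```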

Considerations

Call for Feedback

We’d love feedback from the PyTorch community on this proposal. Please share your thoughts, suggestions, and any potential concerns you may have.

Happy modeling, The Kaggle Team

joecummings commented 3 days ago

Thanks for the RFC @KeijiBranshi - we're really excited about providing another source for model integration. I'm still closely considering some of the implementation details, so I'll respond to those individually.

joecummings commented 3 days ago

Weight formats:

Kaggle hosts model checkpoints in formats intended for different modeling libraries, often differentiated by the names "PyTorch" or "transformers" (Hugging Face). Until now, our only model integration sources have been the Hugging Face Hub and Meta directly, which means we can guarantee that our models load checkpoints in the "transformers" format or the Meta Llama format. We make no guarantees about any other checkpoint formats.

In practice, this means that if a user tries to load, e.g., Gemma 2 using the path google/gemma-2-2b-jpn-it/flax, we should throw an error that "flax" is not supported. Where this gets a little trickier is the "PyTorch" format. For Llama models, this designates the aforementioned Meta Llama format that we already support, but for, e.g., Gemma models it is the native format released by Google, which means we would have to write additional logic to convert it into our torchtune model format.

On a longer time scale, I'd like to support loading checkpoints both in the format they were released in and in the transformers format, but we don't have the bandwidth for that right now. So practically, we'd want to throw an error if someone tries to load a model in any format other than "transformers", unless the organization is "metaresearch", in which case we would also support "PyTorch".
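
As a rough sketch of that restriction, the framework segment of the Kaggle handle (owner/model/framework/variation, as in the examples above) could be checked up front; the function name and error message are illustrative:

```python
SUPPORTED_FRAMEWORKS = {"transformers"}
META_ONLY_FRAMEWORKS = {"pytorch"}  # Meta Llama native checkpoints

def validate_kaggle_handle(handle: str) -> None:
    # Kaggle model handles look like <owner>/<model>/<framework>/<variation>.
    owner, _model, framework, *_rest = handle.split("/")
    framework = framework.lower()
    if framework in SUPPORTED_FRAMEWORKS:
        return
    if framework in META_ONLY_FRAMEWORKS and owner == "metaresearch":
        return
    raise ValueError(
        f"Checkpoint format '{framework}' in '{handle}' is not supported. "
        "Use a 'transformers' variant, or 'pyTorch' for metaresearch models."
    )
```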

Please let me know if I'm missing some details here or if this is too restrictive.

joecummings commented 3 days ago

URI Scheme Recognition:

I think I have a slight preference for a UX similar to llama-stack wherein the source is specified as a param e.g.:

tune download metaresearch/llama-3.2/pyTorch/3b --source kaggle

The reason for this is that the "path" for models on the Hugging Face Hub is very recognizable as the entry point to downloading any of their models e.g.

AutoModel.from_pretrained("openai/whisper-large-v3-turbo")

or

huggingface-cli download openai/whisper-large-v3-turbo

Attaching an "hf://" prefix would slightly obfuscate that recognition. In addition, without a prefix there would be no need for more complex URI scheme handling.
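
A sketch of what that could look like in the download CLI (only --source is the point here; the rest is illustrative argparse scaffolding):

```python
import argparse

parser = argparse.ArgumentParser(prog="tune download")
parser.add_argument("model", help="e.g. metaresearch/llama-3.2/pyTorch/3b")
parser.add_argument(
    "--source",
    choices=["huggingface", "kaggle"],
    default="huggingface",
    help="Hub to download from; defaults to the Hugging Face Hub.",
)
args = parser.parse_args()
```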

Open to thoughts.

joecummings commented 3 days ago

--output-dir Argument:

We're considering changes to our current download process + checkpointing API that would download the model to the source's cache by default. So it should be no problem to make --output-dir optional.

KeijiBranshi commented 3 days ago

Thanks for the thoughtful feedback! We appreciate you taking the time to review our proposal and provide your insights.

KeijiBranshi commented 3 days ago

> Weight formats:
>
> Kaggle hosts model checkpoints in formats intended for different modeling libraries, often differentiated by the names "PyTorch" or "transformers" (Hugging Face). Until now, our only model integration sources have been the Hugging Face Hub and Meta directly, which means we can guarantee that our models load checkpoints in the "transformers" format or the Meta Llama format. We make no guarantees about any other checkpoint formats.
>
> In practice, this means that if a user tries to load, e.g., Gemma 2 using the path google/gemma-2-2b-jpn-it/flax, we should throw an error that "flax" is not supported. Where this gets a little trickier is the "PyTorch" format. For Llama models, this designates the aforementioned Meta Llama format that we already support, but for, e.g., Gemma models it is the native format released by Google, which means we would have to write additional logic to convert it into our torchtune model format.
>
> On a longer time scale, I'd like to support loading checkpoints both in the format they were released in and in the transformers format, but we don't have the bandwidth for that right now. So practically, we'd want to throw an error if someone tries to load a model in any format other than "transformers", unless the organization is "metaresearch", in which case we would also support "PyTorch".
>
> Please let me know if I'm missing some details here or if this is too restrictive.

RE: Weight Formats https://github.com/pytorch/torchtune/issues/1852#issuecomment-2417624558

Agreed that focusing on valid PyTorch and Transformers formats is a good first step. Filtering out incompatible frameworks (e.g. flax) with string manipulation seems straightforward. But restricting "PyTorch" downloads to just the metaresearch models might yield an awkward experience. Namely, when a user publishes a torchtune fine-tuned model to Kaggle, they would not be able to download their own model later using torchtune.

Should we instead consider doing some post-download validation? In other words, download the model payload, but have torchtune check that the files are properly formatted before proceeding?
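
One possible shape for that validation, assuming torchtune only needs to distinguish the "transformers" layout from the Meta Llama layout (the specific file checks are illustrative):

```python
from pathlib import Path

def validate_downloaded_checkpoint(model_dir: str) -> None:
    """Raise if the downloaded files do not match a format torchtune can load."""
    path = Path(model_dir)
    is_hf_format = (path / "config.json").exists() and (
        any(path.glob("*.safetensors")) or any(path.glob("pytorch_model*.bin"))
    )
    is_meta_format = (path / "params.json").exists() and any(path.glob("consolidated.*.pth"))
    if not (is_hf_format or is_meta_format):
        raise ValueError(
            f"Files in {model_dir} do not match a supported checkpoint format "
            "('transformers' or Meta Llama)."
        )
```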

Open to discussing these options and finding the best approach that balances initial simplicity with long-term flexibility.

KeijiBranshi commented 3 days ago

> URI Scheme Recognition:
>
> I think I have a slight preference for a UX similar to llama-stack wherein the source is specified as a param e.g.:
>
> tune download metaresearch/llama-3.2/pyTorch/3b --source kaggle
>
> The reason for this is that the "path" for models on the Hugging Face Hub is very recognizable as the entry point to downloading any of their models e.g.
>
> AutoModel.from_pretrained("openai/whisper-large-v3-turbo")
>
> or
>
> huggingface-cli download openai/whisper-large-v3-turbo
>
> Attaching an "hf://" prefix would slightly obfuscate that recognition. In addition, without a prefix there would be no need for more complex URI scheme handling.
>
> Open to thoughts.

RE: URI Scheme Recognition https://github.com/pytorch/torchtune/issues/1852#issuecomment-2417677159

We're happy to defer to your expertise on the use of the --source parameter, especially if it helps make the UX more consistent across similar libraries.

FWIW, HuggingFace supports the hf:// scheme in some of its own tooling (see documentation). But I understand that it’s not a universal concept for model URIs.

KeijiBranshi commented 3 days ago

> --output-dir Argument:
>
> We're considering changes to our current download process + checkpointing API that would download the model to the source's cache by default. So it should be no problem to make --output-dir optional.

RE: --output-dir Argument https://github.com/pytorch/torchtune/issues/1852#issuecomment-2417682770

Thanks for sharing those considerations around --output-dir. Currently, kagglehub doesn’t allow users to change the download directory, but we have some related requests on our side to support it.
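
Until then, one workaround on the torchtune side could be to copy out of the kagglehub cache after download, since kagglehub.model_download returns the cached path (a sketch, not an existing kagglehub feature):

```python
import shutil

import kagglehub

def download_to_output_dir(handle: str, output_dir: str) -> str:
    # kagglehub always downloads into its own cache and returns that location...
    cache_path = kagglehub.model_download(handle)
    # ...so copy the files into the user-requested --output-dir.
    shutil.copytree(cache_path, output_dir, dirs_exist_ok=True)
    return output_dir
```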

joecummings commented 1 day ago

> RE: Weight Formats #1852 (comment)
>
> Agreed that focusing on valid PyTorch and Transformers formats is a good first step. Filtering out incompatible frameworks (e.g. flax) with string manipulation seems straightforward. But restricting "PyTorch" downloads to just the metaresearch models might yield an awkward experience. Namely, when a user publishes a torchtune fine-tuned model to Kaggle, they would not be able to download their own model later using torchtune.
>
> Should we instead consider doing some post-download validation? In other words, download the model payload, but have torchtune check that the files are properly formatted before proceeding?
>
> Open to discussing these options and finding the best approach that balances initial simplicity with long-term flexibility.

Ahh, this is a great consideration. Post-download validation will be fine.

joecummings commented 1 day ago

> RE: --output-dir Argument #1852 (comment)
>
> Thanks for sharing those considerations around --output-dir. Currently, kagglehub doesn’t allow users to change the download directory, but we have some related requests on our side to support it.

I think it would be a slightly more consistent UX if torchtune users could specify a custom output directory for either Kaggle or the Hugging Face Hub.

But I don't view this as a blocker to this integration.