pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
BSD 3-Clause "New" or "Revised" License

[RFC] Adding support for non-CUDA devices #1053

Open sanchitintel opened 1 month ago

sanchitintel commented 1 month ago

Motivation

Please advise how to cleanly add support for other devices such as CPU. Right now, the device is read from configs, and the built-in configs hard-code CUDA.

Proposed solution

Duplicating the built-in config yaml files for another device (with everything else copied verbatim from the original config) doesn't seem like a good idea. Adding a class member supported_devices to Recipe would imply that every built-in config for that recipe supports every device in supported_devices, which may not be true. For example, paged optimizers are currently unsupported on CPU (I'm not sure CPU could ever support them, as they seem to be GPU-only), so device support should continue to be tracked at the granularity of individual configs.

Please advise whether the following approach would be acceptable instead. Like the current implementation, it checks device support at the granularity of a config, but it would not require hard-coding device names in the built-in configs' yaml files; they would instead be listed centrally in _recipe_registry.py:

  1. Add a List[str] class member supported_devices to Config.
  2. Accept device as a CLI runtime argument.
  3. At runtime, if the config is a built-in config, check that the device passed on the command line is one of its supported_devices; if no device argument is provided, use the default device (CUDA). If the recipe/config is not one of the built-in recipes/configs, simply use the device provided as a runtime argument.
  4. Modify other torchtune code to also support CPU wherever only CUDA is currently supported but CPU could be supported as well.
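Steps 1-3 could be sketched roughly as follows, assuming a plain dataclass stands in for the registry's Config entries. The config names, the BUILTIN_CONFIGS registry, and the resolve_device helper are all hypothetical illustrations, not torchtune's existing API:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional

DEFAULT_DEVICE = "cuda"


@dataclass
class Config:
    """Hypothetical registry entry carrying the proposed supported_devices list."""

    name: str
    file_path: str
    supported_devices: List[str] = field(default_factory=lambda: [DEFAULT_DEVICE])


# Hypothetical built-in configs: one that also runs on CPU, and one that uses a
# paged (GPU-only) optimizer and therefore keeps the CUDA-only default.
BUILTIN_CONFIGS: Dict[str, Config] = {
    "example/lora_single_device": Config(
        name="example/lora_single_device",
        file_path="example/lora_single_device.yaml",
        supported_devices=["cuda", "cpu"],
    ),
    "example/qlora_paged_optimizer": Config(
        name="example/qlora_paged_optimizer",
        file_path="example/qlora_paged_optimizer.yaml",
    ),
}


def resolve_device(config_name: str, cli_device: Optional[str] = None) -> str:
    """Choose the runtime device following step 3 of the proposal."""
    builtin = BUILTIN_CONFIGS.get(config_name)
    if builtin is None:
        # Not a built-in config: use the CLI device as-is (default if omitted).
        return cli_device or DEFAULT_DEVICE
    if cli_device is None:
        # Built-in config and no CLI device: fall back to the default (CUDA).
        return DEFAULT_DEVICE
    if cli_device not in builtin.supported_devices:
        raise ValueError(
            f"{config_name!r} supports {builtin.supported_devices}, got {cli_device!r}"
        )
    return cli_device
```

With this sketch, resolve_device("example/lora_single_device", "cpu") returns "cpu", while resolve_device("example/qlora_paged_optimizer", "cpu") raises a ValueError, which keeps device support at per-config granularity.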

Please advise whether this approach seems acceptable. I could submit a PR for it, incorporating any modifications you'd suggest.

Thank you!

ebsmothers commented 1 month ago

Hi @sanchitintel thanks for creating this issue, I think this is an important area for us to address.

My main question here is around point (1). For what it's worth, associating some extra metadata with the Config and Recipe dataclasses for validation purposes is something I've been thinking about for other applications too, so it's definitely not outside the realm of possibility. But a couple of counterpoints:

Anyway, I am still open to the approach, but my main concern is that (1) may be necessary but not sufficient for device validation, and that we would be sacrificing configurability and CLI generality to achieve it. The alternative would be to rely on a separate validation utility not tied to the config dataclass, but instead handled inside the recipe (i.e. based on individual config fields instead of the config as a whole). This would basically mean extending the existing validation utility. I'd be interested to hear your thoughts on this as an alternative.

(4) sounds great, certainly no concerns there!

sanchitintel commented 1 month ago

Hi @ebsmothers, thanks a lot for your feedback!

I now understand that ensuring CLI generality is a goal for the library.

> I wonder if supported devices is purely a config property

Good point, thanks! I think that while the minimum PyTorch version supporting a particular dtype on a device could continue to be hard-coded, in the future, config yamls may also need to (optionally) specify the oldest PyTorch version which supported a particular feature that the recipe/config requires.
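An optional minimum-version field could be checked with a small helper like the one below. This is a sketch under assumptions: min_torch_version is a hypothetical config key, and the version parsing is deliberately simplified (a library such as packaging would handle pre-releases more accurately):

```python
import re
from typing import Any, Dict, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Extract the leading release numbers, e.g. '2.3.1+cpu' -> (2, 3, 1)."""
    # Simplification: local (+cpu), dev, and pre-release suffixes are dropped,
    # so '2.4.0a0' compares equal to '2.4.0'.
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_min_version(installed: str, minimum: str) -> bool:
    """True if installed >= minimum, padding shorter versions with zeros."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_min_torch_version(cfg: Dict[str, Any], installed: str) -> None:
    """Raise if the config's optional (hypothetical) min_torch_version is not met."""
    minimum = cfg.get("min_torch_version")
    if minimum is not None and not meets_min_version(installed, str(minimum)):
        raise RuntimeError(
            f"This config requires PyTorch >= {minimum}, found {installed}"
        )
```

In a recipe, installed would be torch.__version__. The version should be quoted in the yaml (e.g. "2.10"), since an unquoted 2.10 is parsed as the float 2.1.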

For such changes, as you mentioned, the existing validation utility could be extended instead.

> The alternative would be to rely on a separate validation utility not tied to the config dataclass, but instead handled inside the recipe (i.e. based on individual config fields instead of the config as a whole)

Thanks for the advice! Please clarify if this approach would entail built-in config yamls continuing to specify device. If yes, then with such an implementation, would users still be expected to manually modify config yamls to specify a different device than the one in a built-in config?

ebsmothers commented 1 month ago

> config yamls may also need to (optionally) specify the oldest PyTorch version which supported a particular feature that the recipe/config requires

You read my mind... this is exactly the case I was thinking of when I mentioned "other applications" in my comment. In this case, I agree that we will probably want to attach some extra metadata to the recipe and/or config.

> Please clarify if this approach would entail built-in config yamls continuing to specify device. If yes, then with such an implementation, would users still be expected to manually modify config yamls to specify a different device than the one in a built-in config?

Yes, it would entail specifying the device in the config, but users do not have to modify the config yaml to change it. Instead they can override it via CLI (as with any config field), e.g.

tune run lora_finetune_single_device --config llama3/8B_lora_single_device device=cpu
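To make the override semantics concrete, the sketch below is a simplified, hypothetical stand-in for what a key=value argument such as device=cpu does to the loaded config: it replaces one field in memory and leaves the yaml file on disk untouched. It is not torchtune's actual parser (which also handles value types and other details), and apply_overrides is an illustrative name only:

```python
import copy
from typing import Any, Dict, List


def apply_overrides(cfg: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Return a copy of cfg with CLI-style 'dotted.key=value' overrides applied.

    Values are stored as raw strings; a real parser would also coerce types.
    """
    result = copy.deepcopy(cfg)
    for item in overrides:
        key, sep, value = item.partition("=")
        if not sep or not key:
            raise ValueError(f"Expected key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = result
        for part in parents:
            # Create intermediate sections that do not exist yet.
            node = node.setdefault(part, {})
        node[leaf] = value
    return result
```

For example, apply_overrides({"device": "cuda"}, ["device=cpu"]) yields {"device": "cpu"} while the original dict keeps "cuda".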

Let me know if that makes sense to you @sanchitintel and thanks for the helpful discussion here!

sanchitintel commented 1 month ago

Thanks again for clarifying, @ebsmothers!

> can override it via CLI (as with any config field)

Oh, sorry, I didn't realize such overrides were already possible!

I'll submit a PR enabling the CPU device for recipes/configs that could run on CPU but currently don't, and will extend the existing validation utility where appropriate. Thanks!