pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

torch.mps.*Tensor datatypes #82296

Closed: TV4Fun closed this issue 1 year ago

TV4Fun commented 2 years ago

πŸš€ The feature, motivation and pitch

An issue that has been debated ad nauseam, and which apparently still has no agreed-upon answer as of PyTorch 1.12, is how, or whether, to set a default device for Torch operations. See for example #27878, which has been open for nearly 3 years now. A method many libraries use, and which is recommended in at least a few places on Stack Overflow, is to call set_default_tensor_type with a CUDA tensor type. This works for CUDA, but not for MPS, since there appear to be no equivalent tensor types for that backend. It also creates problems with unclear type names: if, for example, I create a tensor on MPS with w = torch.tensor([1.0], device='mps'), then w.type() returns 'torch.mps.FloatTensor', but this is not actually a valid type. There is no torch.mps module, and if I try to pass it as a string, say with x = w.type('torch.mps.FloatTensor'), I get an error that it is an invalid type.
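
For concreteness, here is a minimal reproduction of the behavior described (on an MPS-enabled 1.12 build):

import torch

w = torch.tensor([1.0], device='mps')
print(w.type())                       # reports 'torch.mps.FloatTensor'
x = w.type('torch.mps.FloatTensor')   # raises: invalid type string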

As near as I can tell, there is no way to create a tensor directly on the MPS device without specifying device='mps' on every single call, which not only clutters code but also makes it very brittle if I happen to miss one call. Please correct me if I am wrong on this. My particular use case is that I would like to add MPS support to ML-Agents, which already supports CUDA by calling torch.set_default_tensor_type(torch.cuda.FloatTensor) when a CUDA device is available. There does not appear to be an equivalent way to do this for MPS, and I have no desire to track down however many hundreds of tensor-creation calls there are in their code. I know there have been calls to deprecate set_default_tensor_type, e.g. #53124, but I would recommend not doing that without providing some other way to set a default device. In the meantime, I would really like an easy way to set the default tensor type to torch.mps.FloatTensor.

Alternatives

Provide a way to set the default device for torch.tensor() and similar calls to MPS.

Additional context

See also #260 and probably others.

cc @ezyang @gchanan @zou3519 @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

albanD commented 2 years ago

cc @ezyang

ezyang commented 2 years ago

With TorchFunctionMode we can do this in a few dozen lines of code. I'll put up a PoC later today.

soumith commented 2 years ago

@ezyang @pbelevich and I were discussing this in a more general context.

With your PoC, can we for example make the default device be a meta device?

ezyang commented 2 years ago

Yes; it will basically be similar to how torchdistx does it.

ezyang commented 2 years ago

This is totally untested:

import torch
from torch.overrides import TorchFunctionMode

_DEVICE_CONSTRUCTOR = {
    # standard ones
    torch.empty,
    torch.empty_strided,
    torch.empty_quantized,
    torch.ones,
    torch.arange,
    torch.bartlett_window,
    torch.blackman_window,
    torch.eye,
    torch.fft.fftfreq,
    torch.fft.rfftfreq,
    torch.full,
    torch.fill,
    torch.hamming_window,
    torch.hann_window,
    torch.kaiser_window,
    torch.linspace,
    torch.logspace,
    torch.nested_tensor,
    # torch.normal,
    torch.rand,
    torch.randn,
    torch.randint,
    torch.randperm,
    torch.range,
    torch.sparse_coo_tensor,
    torch.sparse_compressed_tensor,
    torch.sparse_csr_tensor,
    torch.sparse_csc_tensor,
    torch.sparse_bsr_tensor,
    torch.sparse_bsc_tensor,
    torch.tril_indices,
    torch.triu_indices,
    torch.vander,
    torch.zeros,
    torch.asarray,
    # weird ones
    torch.tensor,
    torch.as_tensor,
}

class DeviceMode(TorchFunctionMode):
    def __init__(self, device):
        super().__init__()  # note: newer builds expect this; the corrected snippet later in this thread includes it
        self.device = torch.device(device)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Fill in the device only for factory functions, and only when the
        # caller did not pass one explicitly, so an explicit device wins.
        if func in _DEVICE_CONSTRUCTOR and kwargs.get('device') is None:
            kwargs['device'] = self.device
        return func(*args, **kwargs)

with DeviceMode(torch.device("meta")):
    print(torch.empty(3))

I'd like to put this into core, but before we can get there, we have to make some policy decisions. For example, if you turn on DeviceMode("mps") and then someone writes torch.randn(2, device="cpu"), does mps override the CPU device, or does the explicit device win? It would be nice for a limited set of power users to try this out, so we can adjust the behavior based on what turns out to be necessary in the wild, and then ship that.
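
(For reference, the PoC above currently takes the second option: it only fills in device when the caller passed none, so an explicit device wins.)

with DeviceMode(torch.device("mps")):
    a = torch.randn(2)                # no device passed: the mode fills in mps
    b = torch.randn(2, device="cpu")  # explicit device is left untouched: cpu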

TV4Fun commented 2 years ago

I would definitely suggest an explicit device overriding the default device mode. Though if you think this could cause problems, you could always raise a warning the first time this happens.
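
For example, a rough sketch of that warn-once behavior (WarningDeviceMode is just an illustrative name; it reuses the _DEVICE_CONSTRUCTOR set from the PoC above):

import warnings
import torch
from torch.overrides import TorchFunctionMode

class WarningDeviceMode(TorchFunctionMode):
    def __init__(self, device):
        super().__init__()
        self.device = torch.device(device)
        self._warned = False

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func in _DEVICE_CONSTRUCTOR:
            explicit = kwargs.get('device')
            if explicit is None:
                kwargs['device'] = self.device
            elif torch.device(explicit) != self.device and not self._warned:
                # Warn only the first time an explicit device diverges from the mode.
                warnings.warn(f"explicit device {explicit!r} overrides the default device mode {self.device}")
                self._warned = True
        return func(*args, **kwargs)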

TV4Fun commented 2 years ago

Is this in the development branch yet? I'd like to be able to try it out.

ezyang commented 2 years ago

it's definitely in the nightly, and it may also work on the most recent official release

TV4Fun commented 2 years ago

Thank you @ezyang. Due to #78681 and #79337, I am unable to test this on Torch nightly with my M1 Mac. Trying this on Torch 1.12.0 gives this error:

TypeError                                 Traceback (most recent call last)
Input In [1], in <cell line: 61>()
     57             return func(*args, **kwargs)
     58         return func(*args, **kwargs)
---> 61 with DeviceMode(torch.device("meta")):
     62     print(torch.empty(3))

File ~/miniforge3/envs/torch-nightly/lib/python3.10/site-packages/torch/utils/_mode_utils.py:28, in _wrap_init.<locals>.wrapped(self, inner, *args, **kwargs)
     25 @functools.wraps(f)
     26 def wrapped(self, *args, inner=undef, **kwargs):
     27     if inner is undef:
---> 28         raise TypeError(
     29             f"missing inner keyword argument; instead of constructing a {meta_init_error_info.mode_class_name} "
     30             f"directly, pass the constructor to push_{meta_init_error_info.mode_name}_mode"
     31         )
     32     self.inner = inner
     33     return f(self, *args, **kwargs)

TypeError: missing inner keyword argument; instead of constructing a TorchDispatchMode directly, pass the constructor to push_torch_dispatch_mode

ezyang commented 2 years ago

Ok, you are missing a bugfix that makes that syntax work; try writing the context manager as MyMode.push(device) instead.

TV4Fun commented 2 years ago

Okay, thank you, that works. If I create this with

with DeviceMode.push(torch.device("mps")):
    print(torch.empty(3))

the created tensor appears to be on the MPS device. Is there a way to do this without having to move all my code inside a with statement, though? It seems to do nothing if I call DeviceMode.push(torch.device("mps")) on its own, or even manually call DeviceMode.push(torch.device("mps")).__enter__(). I understand the idea of using a context manager here, but in practice that still means a lot of work porting code that just assumes you can set a default.

ezyang commented 2 years ago

g = DeviceMode.push(torch.device("mps"))
g.__enter__()

TV4Fun commented 2 years ago

Okay, that works on the simple case. Trying on a more complex example is causing an internal error (again on Torch 1.12.0):

Traceback (most recent call last):
  File "/Users/jcroteau/miniforge3/envs/ml-agents/bin/mlagents-learn", line 33, in <module>
    sys.exit(load_entry_point('mlagents', 'console_scripts', 'mlagents-learn')())
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/learn.py", line 260, in main
    run_cli(parse_command_line())
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/learn.py", line 256, in run_cli
    run_training(run_seed, options, num_areas)
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/learn.py", line 132, in run_training
    tc.start_learning(env_manager)
  File "/Users/jcroteau/code/ml-agents/ml-agents-envs/mlagents_envs/timers.py", line 305, in wrapped
    return func(*args, **kwargs)
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/trainer_controller.py", line 173, in start_learning
    self._reset_env(env_manager)
  File "/Users/jcroteau/code/ml-agents/ml-agents-envs/mlagents_envs/timers.py", line 305, in wrapped
    return func(*args, **kwargs)
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/trainer_controller.py", line 107, in _reset_env
    self._register_new_behaviors(env_manager, env_manager.first_step_infos)
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/trainer_controller.py", line 268, in _register_new_behaviors
    self._create_trainers_and_managers(env_manager, new_behavior_ids)
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/trainer_controller.py", line 166, in _create_trainers_and_managers
    self._create_trainer_and_manager(env_manager, behavior_id)
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/trainer_controller.py", line 142, in _create_trainer_and_manager
    trainer.add_policy(parsed_behavior_id, policy)
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/sac/trainer.py", line 352, in add_policy
    self.optimizer = self.create_sac_optimizer()
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/sac/trainer.py", line 333, in create_sac_optimizer
    return TorchSACOptimizer(  # type: ignore
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/trainers/sac/optimizer_torch.py", line 159, in __init__
    torch.log(
  File "/Users/jcroteau/miniforge3/envs/ml-agents/lib/python3.9/site-packages/torch/overrides.py", line 1738, in wrapped
    return f(self, *args, **kwargs)
  File "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/torch_utils/torch.py", line 93, in __torch_function__
    return func(*args, **kwargs)
  File "/Users/jcroteau/miniforge3/envs/ml-agents/lib/python3.9/site-packages/torch/overrides.py", line 1738, in wrapped
    return f(self, *args, **kwargs)
  File "/Users/jcroteau/miniforge3/envs/ml-agents/lib/python3.9/site-packages/torch/overrides.py", line 1831, in __torch_function__
    return func(*args, **kwargs)
RuntimeError: [srcBuf length] > 0 INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/OperationUtils.mm":363, please report a bug to PyTorch. Placeholder tensor is empty!

The __torch_function__ in "/Users/jcroteau/code/ml-agents/ml-agents/mlagents/torch_utils/torch.py", line 93 is from your DeviceMode class, which I have put into its own module that is loaded by all of my other modules that use Torch (none of them imports torch directly).

This module sets the default device mode with the lines

global _device
global _device_mode
_device = torch.device(device_str)
_device_mode = DeviceMode.push(_device)
_device_mode.__enter__()

The line that triggers this error is https://github.com/Unity-Technologies/ml-agents/blob/main/ml-agents/mlagents/trainers/sac/optimizer_torch.py#L159, which calls torch.as_tensor with a simple float-array argument. Looking at the call stack, other tensors have already been created on the MPS device successfully, so I am not sure why this particular call is causing problems. I will continue to investigate.

ezyang commented 2 years ago

It's possible torch.as_tensor doesn't actually work with meta tensors. You could stub out the implementation in __torch_function__, bypassing the func call with a call to torch.empty with the appropriate shape and dtype.
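
An untested sketch of that stub, as a replacement for DeviceMode.__torch_function__ above (it materializes on CPU only to learn the shape and dtype):

def __torch_function__(self, func, types, args=(), kwargs=None):
    if kwargs is None:
        kwargs = {}
    if func is torch.as_tensor and self.device.type == "meta":
        # Bypass as_tensor entirely: build the data on CPU to discover its
        # shape/dtype, then allocate an empty tensor on the target device.
        cpu = torch.as_tensor(*args, **{k: v for k, v in kwargs.items() if k != 'device'})
        return torch.empty(cpu.shape, dtype=cpu.dtype, device=self.device)
    if func in _DEVICE_CONSTRUCTOR and kwargs.get('device') is None:
        kwargs['device'] = self.device
    return func(*args, **kwargs)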

TV4Fun commented 2 years ago

No, that was not the problem: torch.as_tensor worked in my toy example, and this error still came up when I replaced the above call with torch.tensor. On closer inspection, the problem was that torch.as_tensor was being called with an empty list, and the internal assertion fires on empty inputs rather than handling them. Using a non-empty list fixed the issue. I am not sure whether this bug still exists in the dev build or not.
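
For reference, the failure reduced to something like this on torch 1.12.0 (reconstructed from the description above):

torch.as_tensor([], device="mps")
# RuntimeError: [srcBuf length] > 0 INTERNAL ASSERT FAILED ... Placeholder tensor is empty!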

TV4Fun commented 2 years ago

Confirmed: this bug does not exist in the dev build. I was able to implement your code above with a few tweaks here, and it appears to have set MPS as the default device. Thank you.

kulinseth commented 2 years ago

@ezyang and @TV4Fun , what are the next steps here?

ezyang commented 2 years ago

we merge it to master πŸ‘€

TV4Fun commented 2 years ago

@kulinseth, I would call @ezyang's solution here a hack. It works, but it would still be nice if there were a simpler way to set the default device to MPS.

ezyang commented 2 years ago

What did you find hacky about it in the end?

TV4Fun commented 2 years ago

Just that it involves building a long list of every individual tensor constructor, creating a context manager at global scope, and manually calling __enter__() on it. It could work if you integrate it into Torch and expose a simple set_default function that user code can call, but the finer details should be internal to Torch.

ezyang commented 2 years ago

Yes, so supposing that PyTorch core maintained the internal implementation details, and we gave a "global state" function API matching the old API, would that be fine?
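
Something along these lines (a sketch; the name set_default_device is hypothetical here, wrapping the same DeviceMode from above):

_default_mode = None

def set_default_device(device):
    # Enter a process-wide DeviceMode; pass None to restore the built-in default.
    global _default_mode
    if _default_mode is not None:
        _default_mode.__exit__(None, None, None)
        _default_mode = None
    if device is not None:
        _default_mode = DeviceMode.push(torch.device(device))
        _default_mode.__enter__()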

TV4Fun commented 2 years ago

@ezyang I'd have to see exactly what you are proposing, but that sounds like a better solution.

ezyang commented 2 years ago

Err, it'd be exactly the same code you're running, just in the PyTorch library so you don't have to see the sausage πŸ˜›

TV4Fun commented 2 years ago

Yeah, that works.


TV4Fun commented 1 year ago

BTW, if anyone else wants to use this fix, as of #85593, the reference to torch.nested_tensor should be torch.nested.nested_tensor.
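
That is, in the _DEVICE_CONSTRUCTOR set:

    torch.nested.nested_tensor,  # was torch.nested_tensor before #85593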

Davidvandijcke commented 1 year ago

FYI, I found this feature very useful

bikcrum commented 1 year ago

Putting it all together, I was able to use the "mps" device in the ml-agents (https://github.com/Unity-Technologies/ml-agents) library. I needed to change torch.py to the following:

ml-agents/mlagents/torch_utils/torch.py


import os

from distutils.version import LooseVersion
import pkg_resources
from mlagents.torch_utils import cpu_utils
from mlagents.trainers.settings import TorchSettings
from mlagents_envs.logging_util import get_logger
from torch.overrides import TorchFunctionMode

logger = get_logger(__name__)

def assert_torch_installed():
    # Check that torch version 1.6.0 or later has been installed. If not, refer
    # user to the PyTorch webpage for install instructions.
    torch_pkg = None
    try:
        torch_pkg = pkg_resources.get_distribution("torch")
    except pkg_resources.DistributionNotFound:
        pass
    assert torch_pkg is not None and LooseVersion(torch_pkg.version) >= LooseVersion(
        "1.6.0"
    ), (
        "A compatible version of PyTorch was not installed. Please visit the PyTorch homepage "
        + "(https://pytorch.org/get-started/locally/) and follow the instructions to install. "
        + "Version 1.6.0 and later are supported."
    )

assert_torch_installed()

# This should be the only place that we import torch directly.
# Everywhere else is caught by the banned-modules setting for flake8
import torch  # noqa I201

torch.set_num_threads(cpu_utils.get_num_threads_to_use())
os.environ["KMP_BLOCKTIME"] = "0"

_device = torch.device("cpu")

_DEVICE_CONSTRUCTOR = {
    # standard ones
    torch.empty,
    torch.empty_strided,
    torch.empty_quantized,
    torch.ones,
    torch.arange,
    torch.bartlett_window,
    torch.blackman_window,
    torch.eye,
    torch.fft.fftfreq,
    torch.fft.rfftfreq,
    torch.full,
    torch.fill,
    torch.hamming_window,
    torch.hann_window,
    torch.kaiser_window,
    torch.linspace,
    torch.logspace,
    # torch.nested_tensor,
    # torch.normal,
    torch.rand,
    torch.randn,
    torch.randint,
    torch.randperm,
    torch.range,
    torch.sparse_coo_tensor,
    torch.sparse_compressed_tensor,
    torch.sparse_csr_tensor,
    torch.sparse_csc_tensor,
    torch.sparse_bsr_tensor,
    torch.sparse_bsc_tensor,
    torch.tril_indices,
    torch.triu_indices,
    torch.vander,
    torch.zeros,
    torch.asarray,
    # weird ones
    torch.tensor,
    torch.as_tensor,
}

class DeviceMode(TorchFunctionMode):
    def __init__(self, device):
        super().__init__()
        self.device = torch.device(device)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Fill in the device only for factory calls that didn't specify one.
        if func in _DEVICE_CONSTRUCTOR and kwargs.get('device') is None:
            kwargs['device'] = self.device
        return func(*args, **kwargs)

def set_torch_config(torch_settings: TorchSettings) -> None:
    global _device

    if torch_settings.device is None:

        mps_available = False
        try:
            if torch.backends.mps.is_available():
                mps_available = True
        except AttributeError:
            # torch.backends.mps does not exist on older PyTorch builds
            pass

        if mps_available:
            device_str = "mps"
        elif torch.cuda.is_available():
            device_str = "cuda"
        else:
            device_str = "cpu"
    else:
        device_str = torch_settings.device

    _device = torch.device(device_str)

    if _device.type == "mps":
        DeviceMode.push(_device).__enter__()
    elif _device.type == "cuda":
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        torch.set_default_tensor_type(torch.FloatTensor)
    print(f"default Torch device: {_device}")

# Initialize to default settings
set_torch_config(TorchSettings(device=None))

nn = torch.nn

def default_device():
    return _device

ezyang commented 1 year ago

gonna try to get this in for 2.0

mattiasu96 commented 1 year ago

Is this fixed? The error also propagates to PyTorch-dependent libraries such as SpeechBrain (https://github.com/speechbrain/speechbrain/issues/1794).

TV4Fun commented 1 year ago

You have to use the new torch.utils.device_mode to set your default device.


mattiasu96 commented 1 year ago

That feature doesn't look like it is available yet in the PyTorch docs (or in my local package installation), so I guess it hasn't been released yet, right?


ezyang commented 1 year ago

The feature as landed in the PR has some API changes; in particular, you can just use torch.device as the context manager, and there is also now torch.set_default_device.
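
For example, on a build with the landed changes:

import torch

torch.set_default_device("mps")  # process-wide default for factory functions
x = torch.randn(2)               # allocated on mps

with torch.device("cpu"):        # torch.device doubles as a context manager
    y = torch.randn(2)           # allocated on cpu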

mattiasu96 commented 1 year ago

But this change is only available in the unstable builds from master, right? It doesn't look like it is available in a stable release, am I wrong?

ezyang commented 1 year ago

No, it's never been in stable. However, the snippet in this issue is self-contained, so you can backport it to a sufficiently recent stable release (I think 1.13 only).

mattiasu96 commented 1 year ago

Do you mean this one? https://github.com/pytorch/pytorch/issues/82296#issuecomment-1198613567

ezyang commented 1 year ago

yup

chelseas commented 1 year ago

I am still having this issue ://