keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Keras with PyTorch backend and MPS set as default should use an MPS generator in randperm #19436

Closed ralphrmartin closed 6 months ago

ralphrmartin commented 7 months ago

Keras with the PyTorch backend and MPS set as the default device needs to use an MPS generator in randperm.

The following code

import os
os.environ["KERAS_BACKEND"] = "torch"

import torch

torch.set_default_device('mps')

import keras
import numpy as np
from keras import layers

# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# Load the data and split it between train and test sets
(xx_train, yy_train), (xx_test, yy_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = xx_train.astype("float32") / 255
x_test = xx_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1); use the scaled float arrays, not the raw uint8 ones
x_train = torch.from_numpy(np.expand_dims(x_train, -1))
x_test = torch.from_numpy(np.expand_dims(x_test, -1))
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# convert class vectors to binary class matrices
y_train = torch.from_numpy(keras.utils.to_categorical(yy_train, num_classes).astype("float32"))
y_test = torch.from_numpy(keras.utils.to_categorical(yy_test, num_classes).astype("float32"))

model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)
batch_size = 128
epochs = 15

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

produces the following error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 6
      2 epochs = 15
      4 model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
----> 6 model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

File ~/Pytorch/venv-Pytorch/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ~/Pytorch/venv-Pytorch/lib/python3.12/site-packages/torch/utils/data/dataloader.py:631, in _BaseDataLoaderIter.__next__(self)
    628 if self._sampler_iter is None:
    629     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    630     self._reset()  # type: ignore[call-arg]
--> 631 data = self._next_data()
    632 self._num_yielded += 1
    633 if self._dataset_kind == _DatasetKind.Iterable and \
    634         self._IterableDataset_len_called is not None and \
    635         self._num_yielded > self._IterableDataset_len_called:

File ~/Pytorch/venv-Pytorch/lib/python3.12/site-packages/torch/utils/data/dataloader.py:674, in _SingleProcessDataLoaderIter._next_data(self)
    673 def _next_data(self):
--> 674     index = self._next_index()  # may raise StopIteration
    675     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    676     if self._pin_memory:

File ~/Pytorch/venv-Pytorch/lib/python3.12/site-packages/torch/utils/data/dataloader.py:621, in _BaseDataLoaderIter._next_index(self)
    620 def _next_index(self):
--> 621     return next(self._sampler_iter)

File ~/Pytorch/venv-Pytorch/lib/python3.12/site-packages/torch/utils/data/sampler.py:287, in BatchSampler.__iter__(self)
    285 batch = [0] * self.batch_size
    286 idx_in_batch = 0
--> 287 for idx in self.sampler:
    288     batch[idx_in_batch] = idx
    289     idx_in_batch += 1

File ~/Pytorch/venv-Pytorch/lib/python3.12/site-packages/torch/utils/data/sampler.py:167, in RandomSampler.__iter__(self)
    165 else:
    166     for _ in range(self.num_samples // n):
--> 167         yield from torch.randperm(n, generator=generator).tolist()
    168     yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

File ~/Pytorch/venv-Pytorch/lib/python3.12/site-packages/torch/utils/_device.py:77, in DeviceContext.__torch_function__(self, func, types, args, kwargs)
     75 if func in _device_constructors() and kwargs.get('device') is None:
     76     kwargs['device'] = self.device
---> 77 return func(*args, **kwargs)

RuntimeError: Expected a 'mps:0' generator device but found 'cpu'

SuryanarayanaY commented 7 months ago

Hi @ralphrmartin ,

I have tested the code snippet and I am getting a NotImplementedError, as shown in the gist.

ralphrmartin commented 7 months ago

I'm not quite sure who needs to do what here. Is this a matter for the MPS team? I'm just an end user trying to use this stuff, and I get the error given in my initial report when running on an Apple Silicon MacBook Pro with Python 3.12.2 and the following package versions:

absl-py           2.1.0
appnope           0.1.4
asttokens         2.4.1
comm              0.2.2
contourpy         1.2.1
cycler            0.12.1
debugpy           1.8.1
decorator         5.1.1
executing         2.0.1
filelock          3.13.3
fonttools         4.50.0
fsspec            2024.3.1
h5py              3.10.0
ipykernel         6.29.4
ipython           8.23.0
jedi              0.19.1
Jinja2            3.1.3
jupyter_client    8.6.1
jupyter_core      5.7.2
keras             3.1.1
kiwisolver        1.4.5
markdown-it-py    3.0.0
MarkupSafe        2.1.5
matplotlib        3.8.4
matplotlib-inline 0.1.6
mdurl             0.1.2
ml-dtypes         0.3.2
mpmath            1.3.0
namex             0.0.7
nest-asyncio      1.6.0
networkx          3.2.1
numpy             1.26.4
optree            0.11.0
packaging         24.0
parso             0.8.3
pexpect           4.9.0
pillow            10.3.0
pip               24.0
platformdirs      4.2.0
prompt-toolkit    3.0.43
psutil            5.9.8
ptyprocess        0.7.0
pure-eval         0.2.2
Pygments          2.17.2
pyparsing         3.1.2
python-dateutil   2.9.0.post0
pyzmq             25.1.2
rich              13.7.1
six               1.16.0
stack-data        0.6.3
sympy             1.12
torch             2.2.2
torchvision       0.17.2
tornado           6.4
traitlets         5.14.2
typing_extensions 4.10.0
wcwidth           0.2.13
M7Saad commented 7 months ago

Some operations, such as the 'aten::random_' operator, are currently unsupported on the MPS device in the Torch backend. You can find more information about this issue at https://github.com/pytorch/pytorch/issues/77764. As a temporary workaround, I recommend setting the environment variable PYTORCH_ENABLE_MPS_FALLBACK to 1. This lets Keras use the GPU automatically, so you don't need to set the default device in torch.
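A minimal sketch of this workaround, assuming a fresh Python process (or a restarted Jupyter kernel): the variables are set from Python before torch or keras is imported, on the assumption that both read them only at import time (the Keras flag check linked later in this thread lives at module level in the torch backend). The helper name `configure_torch_mps_env` is hypothetical and is not part of Keras or PyTorch; exporting the variables in the shell before launching Python works just as well.

```python
import os
import sys


def configure_torch_mps_env() -> None:
    """Hypothetical helper: set the environment variables discussed in this thread.

    Assumption: torch and keras read these variables when first imported,
    so this must run before either of them is imported.
    """
    imported = [name for name in ("torch", "keras") if name in sys.modules]
    if imported:
        raise RuntimeError(
            f"Set these variables before importing {', '.join(imported)}; "
            "restart the Python process (or Jupyter kernel) and try again."
        )
    # Select the PyTorch backend for Keras 3.
    os.environ["KERAS_BACKEND"] = "torch"
    # Let operators that MPS does not support fall back to the CPU instead of failing.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


configure_torch_mps_env()

# Only import the frameworks after the environment is configured, e.g.:
# import keras
```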

SuryanarayanaY commented 7 months ago

Hi @ralphrmartin ,

Could you please refer to the comment above from @M7Saad? It seems to be a compatibility issue with PyTorch.

ralphrmartin commented 7 months ago

Thank you.

SuryanarayanaY commented 7 months ago

Hi @ralphrmartin ,

Could you please confirm whether this is a PyTorch compatibility issue? If so, can we mark it as resolved? Thanks!

ralphrmartin commented 7 months ago

Setting PYTORCH_ENABLE_MPS_FALLBACK to 1 prevents the issue, thanks.

SuryanarayanaY commented 7 months ago

@ralphrmartin ,

Thanks for the response. Can we mark this as closed now?

ralphrmartin commented 7 months ago

I guess so, but maybe the documentation needs updating to prevent other users from tripping over this.

grasskin commented 6 months ago

@ralphrmartin Hi Ralph, looking into this more, it seems that PYTORCH_ENABLE_MPS_FALLBACK might have been an experimental flag that is no longer needed. Have you run into this flag in PyTorch in general? Specifically, I'm seeing no mention of it here: https://pytorch.org/docs/stable/notes/mps.html.

If so, we can remove the flag check from https://github.com/keras-team/keras/blob/63586fa698cad7005f561fcdbb5ce590fb2484b1/keras/src/backend/torch/core.py#L24
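To make the effect of that flag check concrete, here is a pure-Python sketch that approximates it. It is not the Keras source: the function `select_default_device` and its boolean parameters are hypothetical stand-ins for `torch.backends.mps.is_available()` and `torch.cuda.is_available()`, and `require_fallback_flag` models the proposed removal of the check. Under these assumptions it is consistent with what is reported in this thread: on an Apple Silicon Mac without the flag, the default device falls through to the CPU.

```python
from typing import Mapping


def select_default_device(
    mps_available: bool,
    cuda_available: bool,
    env: Mapping[str, str],
    require_fallback_flag: bool = True,
) -> str:
    """Hypothetical approximation of a torch backend picking its default device.

    In real code, mps_available / cuda_available would come from
    torch.backends.mps.is_available() and torch.cuda.is_available().
    """
    # MPS is chosen only when it is available and, if required, explicitly opted into.
    flag_ok = (not require_fallback_flag) or env.get("PYTORCH_ENABLE_MPS_FALLBACK") == "1"
    if mps_available and flag_ok:
        return "mps"
    if cuda_available:
        return "cuda"
    return "cpu"


# Apple Silicon Mac, flag unset: MPS is silently skipped.
print(select_default_device(True, False, {}))  # cpu
# Same machine with the flag set.
print(select_default_device(True, False, {"PYTORCH_ENABLE_MPS_FALLBACK": "1"}))  # mps
# Behaviour if the flag check were removed.
print(select_default_device(True, False, {}, require_fallback_flag=False))  # mps
```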

ralphrmartin commented 6 months ago

I am lost at this point. Using

Keras: 3.3.2
Torch: 2.3.0

My original comment still holds: if I don't set PYTORCH_ENABLE_MPS_FALLBACK to 1 and I call torch.set_default_device('mps') (as suggested at https://pytorch.org/docs/stable/notes/mps.html), Keras falls over as described in my initial message, failing to use an MPS generator in randperm.

If I set PYTORCH_ENABLE_MPS_FALLBACK to 1, then the MPS device seems to be used to some extent, but I get

UserWarning: The operator 'aten::_foreach_mul_.Scalar' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. 

If I don't call torch.set_default_device('mps'), then it appears that the MPS device is not used.

So, now what?

grasskin commented 6 months ago

It looks like MPS is stable enough that we can remove the experimental flag; I will submit a separate PR. Thank you for flagging this, Ralph.
