keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0
1.27k stars 115 forks

fixes #550: move torch tensors to CPU before converting to NumPy array #853

Closed: dkgaraujo closed this pull request 11 months ago

dkgaraujo commented 1 year ago

As mentioned in keras-team/keras#18437, one change from keras-team/keras-core#750 is still outstanding: moving the torch tensor to the CPU before converting it to a NumPy array.

This matters on Macs that use the Metal (MPS) GPU device. Calling .cpu() on a torch tensor that is already on the CPU is harmless (no copy is made), but in my testing, converting a tensor that lives on the Metal device directly to NumPy raises an error, so the tensor has to be moved to the CPU first.
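A minimal sketch of the conversion described above, assuming batches arrive as flat tuples of torch tensors or array-likes. The helper names `to_numpy` and `batch_to_numpy` are hypothetical and not part of the keras-core API; torch tensors are detected by duck typing on `.cpu()` so the sketch also runs without PyTorch installed.

```python
import numpy as np


def to_numpy(x):
    """Convert one batch element to a NumPy array (hypothetical helper).

    torch.Tensor objects expose .cpu(); calling it first copies a tensor that
    lives on a GPU device such as "mps" (Apple Metal) into host memory, which
    NumPy can read. For a tensor already on the CPU, .cpu() returns the same
    tensor without copying. Array-likes without .cpu() go straight to np.asarray.
    """
    if hasattr(x, "cpu"):  # duck-type stand-in for isinstance(x, torch.Tensor)
        x = x.cpu()
    return np.asarray(x)


def batch_to_numpy(batch):
    """Apply to_numpy to every element of a flat (x, y[, sample_weight]) tuple."""
    return tuple(to_numpy(x) for x in batch)
```

On an Apple Silicon Mac with PyTorch installed, `to_numpy(torch.ones(3, device="mps"))` should return a NumPy array, whereas calling `np.asarray` directly on the same MPS tensor is expected to fail as described in the PR. Nested batch structures (dicts, nested tuples) would need a tree-mapping utility rather than the flat tuple comprehension used here.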

google-cla[bot] commented 1 year ago

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

dkgaraujo commented 1 year ago

Separately, the instructions for macOS users could be updated: using the GPU with the torch backend also requires setting the PYTORCH_ENABLE_MPS_FALLBACK environment variable, which lets operations not yet implemented for MPS fall back to the CPU. For example:

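import os
# Both variables must be set before keras_core is imported.
os.environ["KERAS_BACKEND"] = "torch"
# Run operations that MPS does not support yet on the CPU instead of failing.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import keras_core as keras
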
codecov[bot] commented 12 months ago

Codecov Report

Patch coverage: 100.00% and no project coverage change.

Comparison is base (f350416) 76.00% compared to head (ae75181) 76.00%. Report is 2 commits behind head on main.

Additional details and impacted files

```diff
@@            Coverage Diff            @@
##             main   keras-team/keras-core#853   +/-   ##
=======================================
  Coverage   76.00%   76.00%
=======================================
  Files         328      328
  Lines       31103    31103
  Branches     6052     6052
=======================================
  Hits        23639    23639
  Misses       5866     5866
  Partials     1598     1598
```

| [Flag](https://app.codecov.io/gh/keras-team/keras-core/pull/853/flags?src=pr&el=flags&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [keras_core](https://app.codecov.io/gh/keras-team/keras-core/pull/853/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `75.91% <100.00%> (ø)` | |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#carryforward-flags-in-the-pull-request-comment) to find out more.

| [Files Changed](https://app.codecov.io/gh/keras-team/keras-core/pull/853?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [...\_core/trainers/data\_adapters/torch\_data\_adapter.py](https://app.codecov.io/gh/keras-team/keras-core/pull/853?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS90cmFpbmVycy9kYXRhX2FkYXB0ZXJzL3RvcmNoX2RhdGFfYWRhcHRlci5weQ==) | `86.95% <100.00%> (ø)` | |

dkgaraujo commented 12 months ago

I have signed the CLA. I'm not sure whether I need to take another step for the check to pass.