keras-team / keras

Deep Learning for humans
http://keras.io/

keras3 doesn't respect explicit `.to` calls on inputs and model params with torch backend. #19392

Open Krovatkin opened 3 months ago

Krovatkin commented 3 months ago

Test

The following test is supposed to run the model below on the CPU rather than on CUDA:

import os

os.environ["KERAS_BACKEND"] = "torch"

import torch
import keras
import numpy as np

# simple MNIST model
def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

def _mp_fn(_):
    device = 'cpu'
    model = get_model()
    model.to(device=device)
    inputs = torch.rand(32, 784, device=device)
    outputs = model(inputs)
    print(outputs)

_mp_fn(None)
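For reference, a quick way to confirm that the .to(device) call really did move the weights (with the torch backend a keras.Model is also a torch.nn.Module) is a small check like the following, which is not part of the failing repro:

# sanity check: every parameter should report device(type='cpu') after model.to(device="cpu")
for name, p in model.named_parameters():
    print(name, p.device)

The parameters do show up on the CPU, consistent with the forward pre-hook output further down.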

Error

Instead, it errors out with the following traceback:

Traceback (most recent call last):
  File "/home/villedepommes/projects/python/keras_torch/luke2.py", line 27, in <module>
    _mp_fn(None)
  File "/home/villedepommes/projects/python/keras_torch/luke2.py", line 24, in _mp_fn
    outputs = model(inputs)
  File "/home/villedepommes/miniconda3/envs/keras_torch/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/villedepommes/miniconda3/envs/keras_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/villedepommes/miniconda3/envs/keras_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/villedepommes/miniconda3/envs/keras_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/villedepommes/miniconda3/envs/keras_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
RuntimeError: Exception encountered when calling Dense.call().

Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)

Env:

Python 3.10.13
>>> torch.__version__
'2.2.1'
>>> keras.__version__
'3.1.1'

Expected Output

The expected behaviour is that the test runs without any errors and prints an output tensor that stays on the CPU.
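In other words, a check along these lines would be expected to pass (a small sketch, not part of the original script):

outputs = model(inputs)
# the output should stay on the CPU, matching the explicit .to(device) call above
assert outputs.device.type == "cpu"
print(outputs.shape)  # torch.Size([32, 10])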

More details:

If I add a forward pre-hook to print out the devices of the inputs and parameters, I can see that Functional converts my CPU input back to the CUDA device (device(type='cpu') -> device='cuda:0'), so the first Dense layer sees a CUDA tensor:

def output_hook(module, input):
    device = "keras" if not hasattr(input[0], "device") else str(input[0].device)
    data_ptrs = [x.data_ptr() for x in module.parameters()]
    parms = [x.device for x in module.parameters()]
    print(f"{module=} {parms=} {device=} {input[0].data_ptr()=} {data_ptrs=}")

torch.nn.modules.module.register_module_forward_pre_hook(output_hook)
module=<Functional name=functional_1, built=True> parms=[device(type='cpu'), device(type='cpu'), device(type='cpu'), device(type='cpu'), device(type='cpu'), device(type='cpu')] device='cpu' input[0].data_ptr()=1083792256 data_ptrs=[1083570560, 1083771776, 1083772160, 1083788672, 1083789056, 1083554240]
module=<Dense name=dense, built=True> parms=[device(type='cpu'), device(type='cpu')] device='cuda:0' input[0].data_ptr()=139952023142400 data_ptrs=[1083570560, 1083771776]
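For what it's worth, the only workaround I can think of (a sketch I haven't validated beyond this snippet, and it assumes keras.device is available in this Keras version) is to run everything inside Keras's own device scope so that the backend's default device matches the device I actually want, instead of relying on .to():

# workaround sketch: force the Keras default device to CPU so Functional
# has no reason to move the input to cuda:0
with keras.device("cpu"):
    model = get_model()
    inputs = torch.rand(32, 784, device="cpu")
    outputs = model(inputs)
print(outputs.device)

But even if that works, the explicit .to() call on the model should still be respected.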
SuryanarayanaY commented 3 months ago

Hi @Krovatkin ,

I have tested the code in Colab and it appears to work in the Colab environment. The installed versions of torch and keras are 2.2.1+cu121 and 3.1.1, respectively. Please refer to the attached gist.

Krovatkin commented 3 months ago

@SuryanarayanaY did you change the runtime to T4?

I re-ran your notebook and it crashed with the exact same error:

[screenshot of the error]

SuryanarayanaY commented 3 months ago

Hi @Krovatkin ,

Thanks for letting me know. I have checked with a GPU runtime and it raises the error as reported. Attached gist for reference. The error seems to be generated from the torch library; not sure whether this is a compatibility issue with the interface.

Escalating to Dev team for their review and comments.

Krovatkin commented 3 months ago

The error seems to be generated from the torch library

@SuryanarayanaY thank you very much for verifying the test!

The error seems to be generated from the torch library

If you would like, I can provide a pure PyTorch version of this test case that avoids this error. Keras 3's Functional moves the input back to the default device ("cuda:0" in this case) from "cpu", which is what causes the error, and Functional isn't a torch class, so the problem isn't in the torch library itself.
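For illustration, that pure-PyTorch equivalent would be roughly the following sketch (same three layers as the Keras model above); plain torch respects the explicit .to("cpu") and never moves the input back to CUDA:

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(784, 64), torch.nn.ReLU(),
    torch.nn.Linear(64, 64), torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
).to("cpu")

inputs = torch.rand(32, 784, device="cpu")
outputs = model(inputs)
print(outputs.device)  # device(type='cpu')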