keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Improve the flexibility of `standardize_dtype` and fix `pad` in torch backend #828

Closed james77777778 closed 1 year ago

james77777778 commented 1 year ago

One of the advantages of Keras Core is that we can mix components from different backends in one workflow. For example, we can train a TensorFlow model using a PyTorch `DataLoader`.

However, operations that call `standardize_dtype` may fail when given a torch dtype (the `dtype` of a `torch.Tensor`, e.g. `torch.float32`) while the backend is NOT torch.

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import torch

from keras_core import ops

x = torch.randn(4, 16, 16, 3)
y = ops.convert_to_tensor(x, dtype=x.dtype)  # failed w/o this PR
print(y.dtype)
```

This PR addresses the issue with a more robust check for torch dtypes. A unit test covering this behavior is included.
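A minimal sketch of the kind of check involved, not the code in this PR: a hypothetical standalone `standardize_dtype` with a shortened, assumed allow-list and an assumed mapping for bare Python types. It recognizes torch dtypes by their string form (`str(torch.float32)` is `"torch.float32"`), so it never has to import torch, and it never compares a dtype object against a bare string such as `"int"`.

```python
import numpy as np

# Hypothetical, shortened allow-list (assumption; not the Keras list).
ALLOWED_DTYPES = {
    "bool", "uint8", "int8", "int16", "int32", "int64",
    "float16", "float32", "float64",
}

# Assumed mapping for bare Python types.
PYTHON_DTYPES = {bool: "bool", int: "int64", float: "float32"}


def standardize_dtype(dtype):
    """Return a canonical dtype name such as "float32" (sketch only)."""
    if dtype is None:
        return "float32"  # assumed default float type
    if isinstance(dtype, str):
        name = dtype
    elif isinstance(dtype, type) and dtype in PYTHON_DTYPES:
        # Dict lookup on the type object itself, so NumPy's string
        # coercion in np.dtype.__eq__ never comes into play.
        name = PYTHON_DTYPES[dtype]
    elif str(dtype).startswith("torch."):
        # torch dtypes print as "torch.float32"; no torch import needed.
        name = str(dtype)[len("torch."):]
    else:
        try:
            # NumPy dtypes and scalar types: np.dtype("int64"), np.float32, ...
            name = np.dtype(dtype).name
        except TypeError:
            # Other dtype-like objects that expose a .name attribute.
            name = getattr(dtype, "name", str(dtype))
    if name not in ALLOWED_DTYPES:
        raise ValueError(f"Invalid dtype: {dtype!r}")
    return name
```

Under these assumptions, `standardize_dtype(torch.randn(2).dtype)` returns `"float32"` regardless of the active backend.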

james77777778 commented 1 year ago

Hi @fchollet, I have optimized `standardize_dtype` for performance as far as I know how.

This change caught subtle bugs in:

Surprisingly, `x.dtype == "int"` is `True` when the dtype is `np.int64`, which caused strange behavior in `standardize_dtype`.

```python
>>> import numpy as np
>>> x = np.array([0.0, 1.0, 3.0, 1.6])
>>> bins = np.array([0.0, 3.0, 4.5, 7.0])
>>> np.digitize(x, bins).dtype
dtype('int64')
>>> np.digitize(x, bins).dtype == "int"
True
```
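The comparison is `True` because `np.dtype.__eq__` first converts a string operand with `np.dtype(...)`, and `np.dtype("int")` is the platform's default integer, which is `int64` on most 64-bit Linux and macOS builds. A small sketch of the pitfall and one way around it (comparing the canonical `dtype.name` string, which involves no coercion); the variable names are illustrative only.

```python
import numpy as np

x = np.array([0.0, 1.0, 3.0, 1.6])
bins = np.array([0.0, 3.0, 4.5, 7.0])
dtype = np.digitize(x, bins).dtype

# np.dtype.__eq__ coerces "int" to the platform default integer,
# so this can be True for an int64 dtype.
loose_match = dtype == "int"

# Comparing the canonical name is a plain string comparison.
strict_match = dtype.name == "int"

print(dtype.name, loose_match, strict_match)
```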

In the torch backend:

james77777778 commented 1 year ago

@fchollet

I have refactored `pad` in the torch backend to accommodate the restrictions of `torch.nn.functional.pad`: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html

> Replicate and reflection padding are implemented for padding the last 3 dimensions of a 4D or 5D input tensor, the last 2 dimensions of a 3D or 4D input tensor, or the last dimension of a 2D or 3D input tensor.

In the example below, reflect padding fails when the pad list covers all 5 dimensions, even though only the last 3 dimensions are actually padded. It works once the redundant zero pairs are removed.

```python
>>> x = torch.ones((2, 3, 4, 5, 6))
>>> torch.nn.functional.pad(x, [2, 3, 1, 1, 1, 1, 0, 0, 0, 0], mode="reflect").shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now
>>> torch.nn.functional.pad(x, [2, 3, 1, 1, 1, 1], mode="reflect").shape
torch.Size([2, 3, 6, 7, 11])
```
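The workaround amounts to dropping the zero pairs for leading axes before calling torch. A minimal sketch, assuming a hypothetical helper `to_torch_pad` that takes Keras-style `pad_width` (one `(before, after)` pair per axis, first axis first) and applies the rank rule from the PyTorch docs quoted above. It is not the PR's implementation, and it only builds the argument list, so it runs without torch.

```python
def to_torch_pad(pad_width, mode="constant"):
    """Convert Keras-style pad_width to the flat list torch.nn.functional.pad takes.

    Hypothetical helper. pad_width holds one (before, after) pair per axis,
    first axis first; torch lists pairs starting from the LAST axis.
    """
    ndim = len(pad_width)
    flat = []
    for before, after in reversed(pad_width):
        flat.extend([int(before), int(after)])
    if mode == "constant":
        # Constant padding accepts a pair for every axis.
        return flat

    # Count the trailing axes that actually need padding (at least one).
    # Pair k, counted from the last axis, lives at flat[2k:2k + 2].
    needed = ndim
    while needed > 1 and flat[2 * needed - 2 : 2 * needed] == [0, 0]:
        needed -= 1

    # Per the torch docs, non-constant padding of the last n axes
    # requires an input of rank n + 1 or n + 2, with n <= 3.
    for n in range(needed, 4):
        if ndim - n in (1, 2):
            # Keep n pairs; any extra pairs kept here are zero padding.
            return flat[: 2 * n]
    raise ValueError(
        f"mode={mode!r} cannot pad the last {needed} axes of a {ndim}D input"
    )
```

For the tensor in the example, `to_torch_pad([(0, 0), (0, 0), (1, 1), (1, 1), (2, 3)], "reflect")` yields `[2, 3, 1, 1, 1, 1]`, the list that `torch.nn.functional.pad(x, ..., mode="reflect")` accepts.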

I have also updated the unit tests to improve coverage across various shapes, dtypes, and modes.