Hi @fchollet
I have updated standardize_dtype to give the best performance as far as I know.
This change caught subtle bugs in:
- digitize (numpy backend)
- pad and isclose (torch backend)

It is surprising that x.dtype == "int" is True when the dtype is np.int64. This results in strange behavior in standardize_dtype.
>>> import numpy as np
>>> x = np.array([0.0, 1.0, 3.0, 1.6])
>>> bins = np.array([0.0, 3.0, 4.5, 7.0])
>>> np.digitize(x, bins).dtype
dtype('int64')
>>> np.digitize(x, bins).dtype == "int"
True
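A minimal sketch of why the loose comparison misbehaves and what a stricter check could look like, assuming we canonicalize through np.dtype(...).name (the helper name here is hypothetical, not the actual Keras implementation):
import numpy as np

def to_canonical_name(dtype):
    # np.dtype("int64") == "int" is True on platforms whose default integer
    # is 64-bit, because NumPy coerces the string "int" to the platform
    # default; comparing canonical names avoids that ambiguity.
    return np.dtype(dtype).name

out = np.digitize(np.array([0.0, 1.0, 3.0, 1.6]), np.array([0.0, 3.0, 4.5, 7.0]))
print(out.dtype == "int")            # True -- loose, platform-dependent comparison
print(to_canonical_name(out.dtype))  # "int64" -- unambiguous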
In the torch backend:
- pad(..., mode="reflect") has been updated. See https://github.com/pytorch/pytorch/issues/40763
- torch.result_type is now used in isclose for consistency (see the sketch below).
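A minimal sketch of the idea, assuming dtype promotion is delegated to torch.result_type before the comparison (not the exact Keras implementation):
import torch

def isclose_sketch(x1, x2):
    x1 = torch.as_tensor(x1)
    x2 = torch.as_tensor(x2)
    # Promote both inputs to a common dtype so the comparison behaves
    # consistently for mixed-precision inputs (e.g. float32 vs float64).
    dtype = torch.result_type(x1, x2)
    return torch.isclose(x1.to(dtype), x2.to(dtype))

print(isclose_sketch(torch.tensor([1.0]), torch.tensor([1.0], dtype=torch.float64)))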
@fchollet
I have refactored pad in the torch backend to accommodate the restriction of torch.nn.functional.pad:
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
Replicate and reflection padding are implemented for padding the last 3 dimensions of a 4D or 5D input tensor, the last 2 dimensions of a 3D or 4D input tensor, or the last dimension of a 2D or 3D input tensor.
In the example below, reflect padding does not work when pad_width covers all 5 dimensions, even though it is effectively a 3D padding. However, it works if we remove the redundant 0 entries.
>>> x = torch.ones((2, 3, 4, 5, 6))
>>> torch.nn.functional.pad(x, [2, 3, 1, 1, 1, 1, 0, 0, 0, 0], mode="reflect").shape
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
NotImplementedError: Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now
>>> torch.nn.functional.pad(x, [2, 3, 1, 1, 1, 1], mode="reflect").shape
torch.Size([2, 3, 6, 7, 11])
>>>
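A minimal sketch of the workaround, assuming pad_width is flattened into torch's last-dim-first ordering and the trailing all-zero pairs are dropped before calling torch.nn.functional.pad (hypothetical helper, not the exact code in this PR):
import torch
import torch.nn.functional as F

def reflect_pad_sketch(x, pad_width):
    # pad_width is a list of (before, after) pairs, one per dimension,
    # in the same order as x.shape (NumPy convention).
    flat = []
    for before, after in reversed(pad_width):  # torch expects last dim first
        flat.extend([before, after])
    # Drop trailing zero pairs so F.pad only sees the dimensions that are
    # actually padded; otherwise mode="reflect" raises NotImplementedError.
    while len(flat) > 2 and flat[-1] == 0 and flat[-2] == 0:
        flat = flat[:-2]
    return F.pad(x, flat, mode="reflect")

x = torch.ones((2, 3, 4, 5, 6))
print(reflect_pad_sketch(x, [(0, 0), (0, 0), (1, 1), (1, 1), (2, 3)]).shape)
# torch.Size([2, 3, 6, 7, 11])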
I have also updated the unit tests to improve coverage for various shapes, dtypes, and modes.
One of the advantages of Keras Core is that we can integrate the workflow with different backends. For example, we can train a TensorFlow model using a torch dataloader.
However, operations that call standardize_dtype might fail when the dtype is a torch.Tensor.dtype and the backend is NOT torch. This PR addresses the issue by implementing a better check for torch.Tensor.dtype. A unit test for this behavior has been included.
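A minimal sketch of the kind of check involved, assuming a torch dtype is detected and mapped to a plain dtype name (hypothetical helper, not the exact code in this PR; the real check presumably avoids importing torch when it is not the active backend):
import torch

def standardize_torch_dtype_sketch(dtype):
    # A torch dtype such as torch.float32 is not a string and has no `.name`
    # attribute, so stringify it and strip the "torch." prefix.
    if isinstance(dtype, torch.dtype):
        return str(dtype).split(".")[-1]  # "torch.float32" -> "float32"
    return str(dtype)

print(standardize_torch_dtype_sketch(torch.float32))  # "float32"
print(standardize_torch_dtype_sketch("int64"))        # "int64"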