Open h-vetinari opened 1 month ago
I looked into the issue. It is an incompatibility with torch.compile(), which is hard to debug.
I will investigate it further.
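For reference, a minimal sketch of how such an incompatibility can be isolated (an assumed debugging approach, not the exact steps used here): run the same function eagerly and under torch.compile() and compare the results.

import torch

def f(x):
    # placeholder op; substitute whatever the failing test exercises
    return torch.log10(x)

x = torch.arange(1.0, 10.0)
eager = f(x)
compiled = torch.compile(f)(x)
# a numerical mismatch (or a compile-time error) points at the torch.compile() path
print(torch.allclose(eager, compiled))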
I can't recreate the failures on conda-forge:
$ mamba list pytorch
# packages in environment at /home/mark/miniforge3/envs/dev:
#
# Name Version Build Channel
pytorch 2.3.0 cpu_mkl_py310h75865b9_101 conda-forge
$ pip list | grep keras
keras 3.3.3 /home/mark/git/keras
tf_keras 2.16.0
pytest -x keras/src/trainers/trainer_test.py::TestTrainer::test_predict_flow_struct_jit
passes. In fact, that whole file's tests pass (I can't get jax, tensorflow, and pytorch all installed at their latest versions, so I skipped some tests).
I loosened some tolerances (an A6000 GPU seems to fail with the default tolerances) and conditionally skipped the jax tests:
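For context on the tolerance changes in the diff below: assertAllClose follows the numpy convention, where rtol scales with the magnitude of the expected value and atol is an absolute bound. A small sketch of the difference (hypothetical values, not taken from the failing tests):

import numpy as np

expected = np.array([3.789511, 3.789511, 3.789511])
observed = expected + 0.05  # GPU result drifting slightly

# fails with the default tolerances, passes once atol is widened
np.testing.assert_allclose(observed, expected, atol=0.1)

# rtol=0.001 allows |observed - expected| up to ~0.001 * |expected|
np.testing.assert_allclose(expected * 1.0005, expected, rtol=0.001)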
diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py
index a4b21e5f5..7d9a80fec 100644
--- a/keras/src/trainers/trainer_test.py
+++ b/keras/src/trainers/trainer_test.py
@@ -639,7 +639,7 @@ class TestTrainer(testing.TestCase, parameterized.TestCase):
model_2.compile(loss="mse", optimizer="adam", steps_per_execution=1)
model_2.fit(x=x, y=y, batch_size=batch_size, verbose=0)
- self.assertAllClose(model.get_weights(), model_2.get_weights())
+ self.assertAllClose(model.get_weights(), model_2.get_weights(), rtol=0.001)
self.assertAllClose(
model.predict(x, batch_size=batch_size),
model_2.predict(x, batch_size=batch_size),
@@ -823,7 +823,7 @@ class TestTrainer(testing.TestCase, parameterized.TestCase):
output = model.predict_on_batch(x)
self.assertIsInstance(output, np.ndarray)
- self.assertAllClose(output[0], np.array([3.789511, 3.789511, 3.789511]))
+ self.assertAllClose(output[0], np.array([3.789511, 3.789511, 3.789511]), atol=0.1)
# With sample weights
logs = model.train_on_batch(x, y, sw)
@@ -831,7 +831,7 @@ class TestTrainer(testing.TestCase, parameterized.TestCase):
logs = model.test_on_batch(x, y, sw)
self.assertAlmostEqual(logs[0], 14.595)
output = model.predict_on_batch(x)
- self.assertAllClose(output[0], np.array([3.689468, 3.689468, 3.689468]))
+ self.assertAllClose(output[0], np.array([3.689468, 3.689468, 3.689468]), atol=0.1)
# With class weights
logs = model.train_on_batch(x, y, class_weight={1: 0.3, 0: 0.2})
@@ -857,7 +857,7 @@ class TestTrainer(testing.TestCase, parameterized.TestCase):
)
output = model.predict_on_batch(x)
self.assertIsInstance(output, np.ndarray)
- self.assertAllClose(output[0], np.array([4.0, 4.0, 4.0]))
+ self.assertAllClose(output[0], np.array([4.0, 4.0, 4.0]), atol=0.1)
logs = model.test_on_batch(x, y)
self.assertIsInstance(logs, list)
@@ -1190,19 +1190,19 @@ class TestTrainer(testing.TestCase, parameterized.TestCase):
history = model.fit(
[np.ones((3, 2)), np.ones((3, 3))], np.ones((3, 2))
).history
- self.assertAllClose(history["loss"], 16.0)
+ self.assertAllClose(history["loss"], 16.0, atol=0.1)
train_out = model.train_on_batch(
[np.ones((3, 2)), np.ones((3, 3))], np.ones((3, 2))
)
- self.assertAllClose(train_out[0], 15.2200)
+ self.assertAllClose(train_out[0], 15.2200, atol=0.1)
eval_out = model.evaluate(
[np.ones((3, 2)), np.ones((3, 3))], np.ones((3, 2))
)
- self.assertAllClose(eval_out[0], 13.0321)
+ self.assertAllClose(eval_out[0], 13.0321, atol=0.1)
eval_out = model.test_on_batch(
[np.ones((3, 2)), np.ones((3, 3))], np.ones((3, 2))
)
- self.assertAllClose(eval_out[0], 13.0321)
+ self.assertAllClose(eval_out[0], 13.0321, atol=0.1)
predict_out = model.predict([np.ones((3, 2)), np.ones((3, 3))])
self.assertEqual(predict_out.shape, (3, 2))
predict_out = model.predict_on_batch([np.ones((3, 2)), np.ones((3, 3))])
diff --git a/keras/src/utils/backend_utils_test.py b/keras/src/utils/backend_utils_test.py
index ef5c8dfa2..2ed89e8be 100644
--- a/keras/src/utils/backend_utils_test.py
+++ b/keras/src/utils/backend_utils_test.py
@@ -1,5 +1,6 @@
import numpy as np
from absl.testing import parameterized
+import pytest
from keras.src import backend
from keras.src import testing
@@ -29,6 +30,7 @@ class BackendUtilsTest(testing.TestCase, parameterized.TestCase):
y = dynamic_backend.numpy.log10(x)
self.assertIsInstance(y, np.ndarray)
elif name == "jax":
+ pytest.importorskip('jax')
import jax
dynamic_backend.set_backend(name)
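Regarding the conditional skip: pytest.importorskip imports the module if it is available and skips the test otherwise, so the jax branch only runs when jax is actually installed. A minimal standalone sketch of that behaviour (not part of the Keras test suite):

import pytest

def test_jax_log10():
    # skips this test (rather than erroring) when jax is not installed
    jnp = pytest.importorskip("jax.numpy")
    assert float(jnp.log10(10.0)) == pytest.approx(1.0)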
PyTorch 2.3 has been out for over a month; the repo here still has https://github.com/keras-team/keras/blob/a243d91e43b4c43fe8d184b541b608b6ddd80f71/requirements.txt#L6-L8
and the referenced issue (#19602) has been closed.
Any plans/timelines for becoming compatible with PyTorch 2.3?
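In case it helps triage, a quick local check (a hypothetical sketch, not the project's official compatibility process) is to install torch 2.3, force the torch backend, and exercise the jit_compile path:

import os
os.environ["KERAS_BACKEND"] = "torch"  # must be set before importing keras

import numpy as np
import torch
import keras

print(torch.__version__)  # expect 2.3.x in this environment

model = keras.Sequential([keras.layers.Dense(1)])
# jit_compile=True is expected to exercise the torch.compile() path on the torch backend
model.compile(optimizer="adam", loss="mse", jit_compile=True)
model.fit(np.ones((8, 4)), np.ones((8, 1)), epochs=1, verbose=0)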