t-kalinowski / deep-learning-with-R-2nd-edition-code

Code from the book "Deep Learning with R, 2nd Edition"
https://blogs.rstudio.com/ai/posts/2022-05-31-deep-learning-with-r-2e/

Mini Xception model differences when keras and keras3 packages are installed #16

Closed jonbry closed 6 months ago

jonbry commented 6 months ago

I have noticed a strange issue with the mini Xception model from Chapter 9. When running it on my Linux machine, it runs smoothly with keras but not with keras3. I haven't run into this issue with other examples since the latest 2.15 release of keras, and I just wanted to see what I may be doing to cause it.

The Linux machine has both keras 2.15 and keras3 0.2.0 installed. When running the code with keras3, only the keras3 package is attached (the terminal shows the r-keras environment). Here are the metric plots using both packages (I restarted RStudio between each example):

With keras package: mini_xception_keras2

With keras3 package: mini_xception_keras3

The accuracy using keras3 is basically 50% across all epochs, which was strange. I ran the same code on a Mac that only had keras3 installed and got results similar to those from the keras package on Linux. You can find the sessionInfo() and py_list_packages() output below for each package on the Linux machine:

Mini Xception-like model with keras

R version 4.2.3 (2023-03-15)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 22.04.4 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so

locale:
 [1] LC_CTYPE=en_US.UTF-8      
 [2] LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8       
 [4] LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8   
 [6] LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8      
 [8] LC_NAME=C                 
 [9] LC_ADDRESS=C              
[10] LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8
[12] LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils    
[5] datasets  methods   base     

other attached packages:
[1] keras_2.15.0

loaded via a namespace (and not attached):
 [1] zip_2.3.1             
 [2] Rcpp_1.0.12           
 [3] pillar_1.9.0          
 [4] compiler_4.2.3        
 [5] base64enc_0.1-3       
 [6] tools_4.2.3           
 [7] zeallot_0.1.0         
 [8] nlme_3.1-162          
 [9] jsonlite_1.8.8        
[10] lifecycle_1.0.4       
[11] tibble_3.2.1          
[12] gtable_0.3.4          
[13] lattice_0.20-45       
[14] mgcv_1.8-42           
[15] pkgconfig_2.0.3       
[16] png_0.1-8             
[17] rlang_1.1.3           
[18] Matrix_1.5-3          
[19] cli_3.6.2             
[20] rstudioapi_0.16.0     
[21] withr_3.0.0           
[22] dplyr_1.1.4           
[23] generics_0.1.3        
[24] vctrs_0.6.5           
[25] rprojroot_2.0.4       
[26] grid_4.2.3            
[27] tidyselect_1.2.1      
[28] reticulate_1.36.0     
[29] glue_1.7.0            
[30] here_1.0.1            
[31] R6_2.5.1              
[32] fansi_1.0.6           
[33] farver_2.1.1          
[34] ggplot2_3.5.0         
[35] magrittr_2.0.3        
[36] whisker_0.4.1         
[37] splines_4.2.3         
[38] listarrays_0.4.0      
[39] scales_1.3.0          
[40] tfruns_1.5.2          
[41] colorspace_2.1-0      
[42] labeling_0.4.3        
[43] tensorflow_2.16.0.9000
[44] utf8_1.2.4            
[45] munsell_0.5.1 

> reticulate::py_list_packages()
                        package     version
1                       absl-py       2.1.0
2                  array_record       0.5.1
3                    astunparse       1.6.3
4                    cachetools       5.3.3
5                       certifi    2024.2.2
6            charset-normalizer       3.3.2
7                         click       8.1.7
8                       dm-tree       0.1.8
9                         etils       1.7.0
10                  flatbuffers     24.3.25
11                       fsspec    2024.3.1
12                         gast       0.5.4
13                  google-auth      2.29.0
14         google-auth-oauthlib       1.2.0
15                 google-pasta       0.2.0
16                       grpcio      1.62.2
17                         h5py      3.11.0
18                         idna         3.7
19          importlib_resources       6.4.0
20                        keras      2.15.0
21                     libclang      18.1.1
22                     Markdown         3.6
23                   MarkupSafe       2.1.5
24                    ml-dtypes       0.3.2
25                        numpy      1.26.4
26           nvidia-cublas-cu12    12.2.5.6
27       nvidia-cuda-cupti-cu12    12.2.142
28        nvidia-cuda-nvcc-cu12    12.2.140
29       nvidia-cuda-nvrtc-cu12    12.2.140
30     nvidia-cuda-runtime-cu12    12.2.140
31            nvidia-cudnn-cu12    8.9.4.25
32            nvidia-cufft-cu12  11.0.8.103
33           nvidia-curand-cu12  10.3.3.141
34         nvidia-cusolver-cu12  11.5.2.141
35         nvidia-cusparse-cu12  12.1.2.141
36             nvidia-nccl-cu12      2.16.5
37        nvidia-nvjitlink-cu12    12.2.140
38                     oauthlib       3.2.2
39                   opt-einsum       3.3.0
40                    packaging        24.0
41                       pandas       2.2.2
42                       pillow      10.3.0
43                      promise         2.3
44                     protobuf      3.20.3
45                       psutil       5.9.8
46                       pyasn1       0.6.0
47               pyasn1_modules       0.4.0
48                        pydot       2.0.0
49                    pyparsing       3.1.2
50              python-dateutil 2.9.0.post0
51                         pytz      2024.1
52                     requests      2.31.0
53            requests-oauthlib       2.0.0
54                          rsa         4.9
55                        scipy      1.13.0
56                          six      1.16.0
57                  tensorboard      2.15.2
58      tensorboard-data-server       0.7.2
59                   tensorflow      2.15.1
60          tensorflow-datasets       4.9.4
61         tensorflow-estimator      2.15.0
62               tensorflow-hub      0.16.1
63 tensorflow-io-gcs-filesystem      0.36.0
64          tensorflow-metadata      1.15.0
65                    termcolor       2.4.0
66                     tf_keras      2.15.1
67                         toml      0.10.2
68                         tqdm      4.66.2
69            typing_extensions      4.11.0
70                       tzdata      2024.1
71                      urllib3       2.2.1
72                     Werkzeug       3.0.2
73                        wrapt      1.14.1
74                         zipp      3.18.1
                            requirement
1                        absl-py==2.1.0
2                   array_record==0.5.1
3                     astunparse==1.6.3
4                     cachetools==5.3.3
5                     certifi==2024.2.2
6             charset-normalizer==3.3.2
7                          click==8.1.7
8                        dm-tree==0.1.8
9                          etils==1.7.0
10                 flatbuffers==24.3.25
11                     fsspec==2024.3.1
12                          gast==0.5.4
13                  google-auth==2.29.0
14          google-auth-oauthlib==1.2.0
15                  google-pasta==0.2.0
16                       grpcio==1.62.2
17                         h5py==3.11.0
18                            idna==3.7
19           importlib_resources==6.4.0
20                        keras==2.15.0
21                     libclang==18.1.1
22                        Markdown==3.6
23                    MarkupSafe==2.1.5
24                     ml-dtypes==0.3.2
25                        numpy==1.26.4
26         nvidia-cublas-cu12==12.2.5.6
27     nvidia-cuda-cupti-cu12==12.2.142
28      nvidia-cuda-nvcc-cu12==12.2.140
29     nvidia-cuda-nvrtc-cu12==12.2.140
30   nvidia-cuda-runtime-cu12==12.2.140
31          nvidia-cudnn-cu12==8.9.4.25
32        nvidia-cufft-cu12==11.0.8.103
33       nvidia-curand-cu12==10.3.3.141
34     nvidia-cusolver-cu12==11.5.2.141
35     nvidia-cusparse-cu12==12.1.2.141
36             nvidia-nccl-cu12==2.16.5
37      nvidia-nvjitlink-cu12==12.2.140
38                      oauthlib==3.2.2
39                    opt-einsum==3.3.0
40                      packaging==24.0
41                        pandas==2.2.2
42                       pillow==10.3.0
43                         promise==2.3
44                     protobuf==3.20.3
45                        psutil==5.9.8
46                        pyasn1==0.6.0
47                pyasn1_modules==0.4.0
48                         pydot==2.0.0
49                     pyparsing==3.1.2
50         python-dateutil==2.9.0.post0
51                         pytz==2024.1
52                     requests==2.31.0
53             requests-oauthlib==2.0.0
54                             rsa==4.9
55                        scipy==1.13.0
56                          six==1.16.0
57                  tensorboard==2.15.2
58       tensorboard-data-server==0.7.2
59                   tensorflow==2.15.1
60           tensorflow-datasets==4.9.4
61         tensorflow-estimator==2.15.0
62               tensorflow-hub==0.16.1
63 tensorflow-io-gcs-filesystem==0.36.0
64          tensorflow-metadata==1.15.0
65                     termcolor==2.4.0
66                     tf_keras==2.15.1
67                         toml==0.10.2
68                         tqdm==4.66.2
69            typing_extensions==4.11.0
70                       tzdata==2024.1
71                       urllib3==2.2.1
72                      Werkzeug==3.0.2
73                        wrapt==1.14.1
74                         zipp==3.18.1

Mini Xception with keras3

> sessionInfo()
R version 4.2.3 (2023-03-15)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 22.04.4 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C               LC_TIME=en_US.UTF-8       
 [4] LC_COLLATE=en_US.UTF-8     LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                  LC_ADDRESS=C              
[10] LC_TELEPHONE=C             LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] keras3_0.2.0

loaded via a namespace (and not attached):
 [1] zip_2.3.1              Rcpp_1.0.12            pillar_1.9.0           compiler_4.2.3        
 [5] base64enc_0.1-3        tools_4.2.3            zeallot_0.1.0          nlme_3.1-162          
 [9] jsonlite_1.8.8         lifecycle_1.0.4        tibble_3.2.1           gtable_0.3.4          
[13] lattice_0.20-45        mgcv_1.8-42            pkgconfig_2.0.3        png_0.1-8             
[17] rlang_1.1.3            Matrix_1.5-3           cli_3.6.2              rstudioapi_0.16.0     
[21] fastmap_1.1.1          withr_3.0.0            dplyr_1.1.4            generics_0.1.3        
[25] vctrs_0.6.5            rprojroot_2.0.4        grid_4.2.3             tidyselect_1.2.1      
[29] reticulate_1.36.0      glue_1.7.0             here_1.0.1             R6_2.5.1              
[33] fansi_1.0.6            farver_2.1.1           ggplot2_3.5.0          magrittr_2.0.3        
[37] whisker_0.4.1          splines_4.2.3          listarrays_0.4.0       scales_1.3.0          
[41] tfruns_1.5.2           colorspace_2.1-0       labeling_0.4.3         tensorflow_2.16.0.9000
[45] utf8_1.2.4             munsell_0.5.1         

> reticulate::py_list_packages()
                        package               version                          requirement
1                       absl-py                 2.1.0                       absl-py==2.1.0
2                  array_record                 0.5.1                  array_record==0.5.1
3                     asttokens                 2.4.1                     asttokens==2.4.1
4                    astunparse                 1.6.3                    astunparse==1.6.3
5                        bleach                 6.1.0                        bleach==6.1.0
6                       certifi              2024.2.2                    certifi==2024.2.2
7            charset-normalizer                 3.3.2            charset-normalizer==3.3.2
8                         click                 8.1.7                         click==8.1.7
9                     decorator                 5.1.1                     decorator==5.1.1
10                      dm-tree                 0.1.8                       dm-tree==0.1.8
11                        etils                 1.7.0                         etils==1.7.0
12               exceptiongroup                 1.2.1                exceptiongroup==1.2.1
13                    executing                 2.0.1                     executing==2.0.1
14                  flatbuffers               24.3.25                 flatbuffers==24.3.25
15                       fsspec              2024.3.1                     fsspec==2024.3.1
16                         gast                 0.5.4                          gast==0.5.4
17                 google-pasta                 0.2.0                  google-pasta==0.2.0
18                       grpcio                1.62.2                       grpcio==1.62.2
19                         h5py                3.11.0                         h5py==3.11.0
20                         idna                   3.7                            idna==3.7
21          importlib_resources                 6.4.0           importlib_resources==6.4.0
22                      ipython                8.23.0                      ipython==8.23.0
23                          jax                0.4.26                          jax==0.4.26
24                       jaxlib 0.4.26+cuda12.cudnn89        jaxlib==0.4.26+cuda12.cudnn89
25                         jedi                0.19.1                         jedi==0.19.1
26                       kaggle                1.6.12                       kaggle==1.6.12
27                        keras                 3.3.2                         keras==3.3.2
28                     libclang                18.1.1                     libclang==18.1.1
29                     Markdown                   3.6                        Markdown==3.6
30               markdown-it-py                 3.0.0                markdown-it-py==3.0.0
31                   MarkupSafe                 2.1.5                    MarkupSafe==2.1.5
32            matplotlib-inline                 0.1.7             matplotlib-inline==0.1.7
33                        mdurl                 0.1.2                         mdurl==0.1.2
34                    ml-dtypes                 0.3.2                     ml-dtypes==0.3.2
35                        namex                 0.0.8                         namex==0.0.8
36                        numpy                1.26.4                        numpy==1.26.4
37           nvidia-cublas-cu12              12.3.4.1         nvidia-cublas-cu12==12.3.4.1
38       nvidia-cuda-cupti-cu12              12.3.101     nvidia-cuda-cupti-cu12==12.3.101
39        nvidia-cuda-nvcc-cu12              12.3.107      nvidia-cuda-nvcc-cu12==12.3.107
40       nvidia-cuda-nvrtc-cu12              12.3.107     nvidia-cuda-nvrtc-cu12==12.3.107
41     nvidia-cuda-runtime-cu12              12.3.101   nvidia-cuda-runtime-cu12==12.3.101
42            nvidia-cudnn-cu12              8.9.7.29          nvidia-cudnn-cu12==8.9.7.29
43            nvidia-cufft-cu12             11.0.12.1         nvidia-cufft-cu12==11.0.12.1
44           nvidia-curand-cu12            10.3.4.107       nvidia-curand-cu12==10.3.4.107
45         nvidia-cusolver-cu12            11.5.4.101     nvidia-cusolver-cu12==11.5.4.101
46         nvidia-cusparse-cu12            12.2.0.103     nvidia-cusparse-cu12==12.2.0.103
47             nvidia-nccl-cu12                2.19.3             nvidia-nccl-cu12==2.19.3
48        nvidia-nvjitlink-cu12              12.3.101      nvidia-nvjitlink-cu12==12.3.101
49                   opt-einsum                 3.3.0                    opt-einsum==3.3.0
50                       optree                0.11.0                       optree==0.11.0
51                    packaging                  24.0                      packaging==24.0
52                       pandas                 2.2.2                        pandas==2.2.2
53                        parso                 0.8.4                         parso==0.8.4
54                      pexpect                 4.9.0                       pexpect==4.9.0
55                       pillow                10.3.0                       pillow==10.3.0
56                      promise                   2.3                         promise==2.3
57               prompt-toolkit                3.0.43               prompt-toolkit==3.0.43
58                     protobuf                3.20.3                     protobuf==3.20.3
59                       psutil                 5.9.8                        psutil==5.9.8
60                   ptyprocess                 0.7.0                    ptyprocess==0.7.0
61                    pure-eval                 0.2.2                     pure-eval==0.2.2
62                        pydot                 2.0.0                         pydot==2.0.0
63                     Pygments                2.17.2                     Pygments==2.17.2
64                    pyparsing                 3.1.2                     pyparsing==3.1.2
65              python-dateutil           2.9.0.post0         python-dateutil==2.9.0.post0
66               python-slugify                 8.0.4                python-slugify==8.0.4
67                         pytz                2024.1                         pytz==2024.1
68                     requests                2.31.0                     requests==2.31.0
69                         rich                13.7.1                         rich==13.7.1
70                        scipy                1.13.0                        scipy==1.13.0
71                          six                1.16.0                          six==1.16.0
72                   stack-data                 0.6.3                    stack-data==0.6.3
73                  tensorboard                2.16.2                  tensorboard==2.16.2
74      tensorboard-data-server                 0.7.2       tensorboard-data-server==0.7.2
75                   tensorflow                2.16.1                   tensorflow==2.16.1
76          tensorflow-datasets                 4.9.4           tensorflow-datasets==4.9.4
77 tensorflow-io-gcs-filesystem                0.36.0 tensorflow-io-gcs-filesystem==0.36.0
78          tensorflow-metadata                1.15.0          tensorflow-metadata==1.15.0
79                    termcolor                 2.4.0                     termcolor==2.4.0
80               text-unidecode                   1.3                  text-unidecode==1.3
81                         toml                0.10.2                         toml==0.10.2
82                         tqdm                4.66.2                         tqdm==4.66.2
83                    traitlets                5.14.3                    traitlets==5.14.3
84            typing_extensions                4.11.0            typing_extensions==4.11.0
85                       tzdata                2024.1                       tzdata==2024.1
86                      urllib3                 2.2.1                       urllib3==2.2.1
87                      wcwidth                0.2.13                      wcwidth==0.2.13
88                 webencodings                 0.5.1                  webencodings==0.5.1
89                     Werkzeug                 3.0.2                      Werkzeug==3.0.2
90                        wrapt                1.16.0                        wrapt==1.16.0
91                         zipp                3.18.1                         zipp==3.18.1
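
To see at a glance which pins differ between the two listings above, a small comparison helps. The following is a minimal sketch in plain Python, not part of the original report: the dictionaries hold a handful of versions copied from the two listings, and `diff_pins` is a hypothetical helper name.

```python
# Pinned versions copied from the two py_list_packages() listings above
# (only a handful of packages, to keep the sketch short).
keras2_env = {
    "keras": "2.15.0",
    "tensorflow": "2.15.1",
    "tensorboard": "2.15.2",
    "tf_keras": "2.15.1",
    "nvidia-cudnn-cu12": "8.9.4.25",
    "numpy": "1.26.4",
}
keras3_env = {
    "keras": "3.3.2",
    "tensorflow": "2.16.1",
    "tensorboard": "2.16.2",
    "jax": "0.4.26",
    "nvidia-cudnn-cu12": "8.9.7.29",
    "numpy": "1.26.4",
}


def diff_pins(a, b):
    """Map each package whose pin differs to (version_in_a, version_in_b).

    None means the package is absent from that environment.
    """
    changed = {}
    for name in sorted(set(a) | set(b)):
        va, vb = a.get(name), b.get(name)
        if va != vb:
            changed[name] = (va, vb)
    return changed


for name, (old, new) in diff_pins(keras2_env, keras3_env).items():
    print(f"{name}: {old} -> {new}")
```

Packages pinned identically in both environments (numpy here) drop out, which leaves the Keras, TensorFlow, and CUDA library versions as the candidates worth investigating.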

Let me know if there is any additional information I can provide to help troubleshoot the issue.

Thank you!

jonbry commented 6 months ago

Here's the plot of the mini Xception model with keras3 using the Mac:

mini_Xception_mac

I forgot to copy the session and package info, but I can run it again if it's helpful. It's crazy that running this model with an old NVIDIA graphics card (even one that is ~7 years old) is 4x faster than an M1 Mac. It'll be interesting to see how this changes with MLX.

t-kalinowski commented 6 months ago

Looking into this now. If I'm understanding correctly, the model fails to train only when on Linux using the GPU, and it trains fine when using the CPU only on Mac?

jonbry commented 6 months ago

> model fails to train only when on Linux using the GPU

It only fails when using keras3 on the Linux computer, which should be using the GPU. It works if I use the latest keras package on the Linux computer and when using keras3 on the Mac.

t-kalinowski commented 6 months ago

My best guess right now is that this is an upstream bug related to CUDA/TensorFlow/Keras. Probably, disabling the GPU on Linux (e.g., Sys.setenv(CUDA_VISIBLE_DEVICES="")) would produce the same results seen on the Mac.
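
For anyone reproducing this from Python rather than R, the same environment variable can be set with `os.environ`. This is a minimal sketch, assuming the variable is set before the framework initializes CUDA (in practice, before it is imported); the commented-out TensorFlow lines are illustrative and are not executed here.

```python
import os

# Hide all CUDA devices from this process. This has to happen before the
# deep learning framework initializes CUDA, so set it before the import.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Only import the framework after the variable is set, for example:
#   import tensorflow as tf
#   tf.config.list_physical_devices("GPU")  # should then report no GPUs
print(repr(os.environ["CUDA_VISIBLE_DEVICES"]))
```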

jonbry commented 6 months ago

> My best guess right now is that this is an upstream bug related to CUDA/TensorFlow/Keras. Probably, disabling the GPU on Linux (e.g., Sys.setenv(CUDA_VISIBLE_DEVICES="")) would produce the same results seen on the Mac.

Ok, I'll give this a shot. So the bug could affect the keras3 package and not keras?

t-kalinowski commented 6 months ago

I can reproduce on Linux with a GPU:

image

jonbry commented 6 months ago

> I can reproduce on Linux with a GPU:

I can't tell you how happy this makes me. I was reading through the Keras 3 migration docs last night to see if something may have changed before I realized I could just try it on the Mac. I was very confused when it worked on the Mac, but not on Linux. I'm running it (slowly) on the CPU right now and will let you know how it turns out.

t-kalinowski commented 6 months ago

Same behavior with use_backend("jax")

image

jonbry commented 6 months ago

Screenshot from 2024-04-25 14-55-08

I'm calling it at 31 epochs. Same results with keras3/CPU as keras3/GPU. If I run tf$config$list_physical_devices("GPU") and it returns list(), does that mean it's using the CPU for both keras and tensorflow?

t-kalinowski commented 6 months ago

> If I run tf$config$list_physical_devices("GPU") and it returns list(), does that mean it's using the CPU for both keras and tensorflow?

Yes (assuming keras is using the tensorflow backend)

t-kalinowski commented 6 months ago

Hmm, using TF 2.16 with {keras} 2 (via Sys.setenv(TF_USE_LEGACY_KERAS = "1")), the issue is still present. I'm starting to suspect the issue is with TF 2.16. Will try TF 2.15 next.

Note to self, needed some code changes to get this working:

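# Ask TensorFlow to use the legacy Keras 2 implementation
Sys.setenv(TF_USE_LEGACY_KERAS = "1")

# When a Keras 3 class name is present, also add the matching
# legacy class name right after it in the class vector
reticulate::register_class_filter(function(x) {
  if(!is.na(m <- match("keras.src.models.model.Model", x)))
    x <- unique(append(x, "keras.engine.training.Model", after = m))

  if(!is.na(m <- match("keras.src.models.sequential.Sequential", x)))
    x <- unique(append(x, "keras.engine.sequential.Sequential", after = m))

  x
})

# Alias load_model to load_model_tf, and make dim() read the shape attribute
load_model <- load_model_tf
dim <- function(x) unlist(x$shape)

jonbry commented 6 months ago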

Yep, using tensorflow. On the Mac, I am using tensorflow-metal, which is enabled for GPU. Does that mean that running the keras3 model on the Mac was on the GPU rather than the CPU?

t-kalinowski commented 6 months ago

Everything seems to be working fine with R {keras}, TF 2.15, Keras 2.15 (default keras in TF 2.15)

image

t-kalinowski commented 6 months ago

It's noteworthy that training is almost 2x slower in TF 2.16 vs TF 2.15 for this example.

t-kalinowski commented 6 months ago

Running the same code in Python also fails to train the model. This confirms that it's not a bug in the R interface, but somewhere else in the stack:

image

Python code (adapted from here)

```python
import keras
from keras import layers
import pathlib
from keras.utils import image_dataset_from_directory

new_base_dir = pathlib.Path("cats_vs_dogs_small")

train_dataset = image_dataset_from_directory(
    new_base_dir / "train",
    image_size=(180, 180),
    batch_size=32)
validation_dataset = image_dataset_from_directory(
    new_base_dir / "validation",
    image_size=(180, 180),
    batch_size=32)
test_dataset = image_dataset_from_directory(
    new_base_dir / "test",
    image_size=(180, 180),
    batch_size=32)

data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)

inputs = keras.Input(shape=(180, 180, 3))
x = data_augmentation(inputs)

x = layers.Rescaling(1./255)(x)
x = layers.Conv2D(filters=32, kernel_size=5, use_bias=False)(x)

for size in [32, 64, 128, 256, 512]:
    residual = x

    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.SeparableConv2D(size, 3, padding="same", use_bias=False)(x)

    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.SeparableConv2D(size, 3, padding="same", use_bias=False)(x)

    x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

    residual = layers.Conv2D(
        size, 1, strides=2, padding="same", use_bias=False)(residual)
    x = layers.add([x, residual])

x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

model.compile(loss="binary_crossentropy",
              optimizer="rmsprop",
              metrics=["accuracy"])
history = model.fit(
    train_dataset,
    epochs=100,
    validation_data=validation_dataset)
```

jonbry commented 6 months ago

> Running the same code in Python also fails to train the model. This confirms that it's not a bug in the R interface, but somewhere else in the stack

Since this is happening with TensorFlow and JAX on Linux, regardless of whether it's running on the CPU or GPU, would it make sense for me to open an issue in the keras repository? I went through the keras issues that have been created in the last few months and I didn't see any related to this one. It seems most people experience this problem when they use the wrong activation in the last layer (not sigmoid), which isn't the case here.

It also appears to affect this example with Linux using keras3. For some reason it actually runs without throwing the ellipse error, but the accuracy is still 50%.
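
The symptom described in this thread, accuracy pinned near 50% on a two-class problem, amounts to chance-level performance when the classes are balanced (an assumption here; the thread does not state the class balance). A minimal sketch for flagging it from a list of per-epoch accuracies follows. The helper name `stuck_at_chance`, its tolerance, and the sample numbers are all made up for illustration and are not read from the plots above.

```python
def stuck_at_chance(accuracies, chance=0.5, tol=0.02):
    """Return True if every epoch's accuracy lies within `tol` of `chance`."""
    if not accuracies:
        raise ValueError("need at least one epoch of accuracy values")
    return all(abs(acc - chance) <= tol for acc in accuracies)


# Made-up numbers for illustration only, not taken from this thread's plots.
flat_run = [0.50, 0.51, 0.49, 0.50, 0.50]
learning_run = [0.55, 0.63, 0.71, 0.78, 0.84]

print(stuck_at_chance(flat_run))      # True
print(stuck_at_chance(learning_run))  # False
```

A check like this only flags the symptom. It says nothing about whether the cause is the model, the data pipeline, or the framework.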

t-kalinowski commented 6 months ago

Yes please! I was planning to open an issue today... ~I think it might be related to the serializing callback and tied to https://github.com/rstudio/reticulate/issues/1601.~ Still trying to get the full picture.

If you file an issue upstream, please link back here and I'll add context as needed.

jonbry commented 6 months ago

I just opened an issue with keras: https://github.com/keras-team/keras/issues/19623

jonbry commented 6 months ago

Looks good: image

Since Keras 3.3.3 fixes the issue, would you like me to close the issue or keep it open until v3.3.3 gets included in keras3?

Thanks for all of your help getting this resolved!
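
To check whether an installed Keras 3 release is at least the 3.3.3 mentioned above, a simple version comparison is enough. This is a minimal sketch in plain Python: `version_tuple` and `has_fix` are hypothetical helper names, and the check is only meaningful for Keras 3.x, since the thread reports Keras 2.15 as working.

```python
import re

FIXED_IN = (3, 3, 3)  # Keras release reported in this thread as fixing the issue


def version_tuple(version):
    """Turn '3.3.2' into (3, 3, 2), ignoring any non-numeric suffix."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def has_fix(version):
    """True if `version` is at or beyond the release that contains the fix."""
    return version_tuple(version) >= FIXED_IN


print(has_fix("3.3.2"))  # False
print(has_fix("3.3.3"))  # True
```

In a live session the version string could come from `keras.__version__` instead of a literal.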

t-kalinowski commented 6 months ago

@jonbry This report helped flush out two very excellent bugs.

Fix 1: https://github.com/keras-team/keras/commit/5883a25f1b7c6eacc3f21f1821751a4109700796

Fix 2: https://github.com/rstudio/reticulate/pull/1602

A heartfelt Thank You!

Please keep the bug reports coming!

cc: @fchollet