t-kalinowski / deep-learning-with-R-2nd-edition-code

Code from the book "Deep Learning with R, 2nd Edition"
https://blogs.rstudio.com/ai/posts/2022-05-31-deep-learning-with-r-2e/

Using layer_lambda() with application_preprocess_inputs() #15

Closed jonbry closed 6 months ago

jonbry commented 7 months ago

I am currently working on the feature extraction with data augmentation example in Chapter 8 and I have run into a bit of an issue while adjusting the code to work with the keras3 package.

The original code with keras:

outputs <- inputs %>%
  data_augmentation() %>%
  imagenet_preprocess_input() %>%
  conv_base() %>%

If I try to change imagenet_preprocess_input() to application_preprocess_inputs(), I get the following error when I build the model:

Error in py_call_impl(callable, call_args$unnamed, call_args$named) : 
  TypeError: Cannot serialize object Ellipsis of type <class 'ellipsis'>. To be serializable, a class must implement the `get_config()` method.
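
For readers coming from R: the `Ellipsis` named in this message is Python's built-in `...` object, which usually shows up when code indexes a tensor with something like `x[..., ::-1]` (a common way to flip channel order). Whether that is the exact expression involved here is an assumption; the error text alone does not confirm it. The sketch below uses only the Python standard library to show where such an object comes from and why a generic serializer rejects it. `IndexRecorder` is a hypothetical toy class, not part of Keras.

```python
import json


class IndexRecorder:
    """Hypothetical stand-in for a tensor: returns the key it was indexed with."""

    def __getitem__(self, key):
        return key


x = IndexRecorder()

# Writing `...` inside an index passes the built-in singleton `Ellipsis`.
key = x[..., ::-1]
print(key)           # (Ellipsis, slice(None, None, -1))
print(type(key[0]))  # <class 'ellipsis'>

# A generic serializer (JSON here, as an illustration) has no representation
# for that object, so saving anything that still holds it fails.
try:
    json.dumps({"index": key[0]})
    serialized = True
except TypeError as err:
    serialized = False
    message = str(err)
    print(message)
```

This is only an analogy for the Keras error: if a layer's configuration still holds such an index object when the model is written to disk, serialization can fail in a similar way.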

It looks like those using the Python book have the same issue, and they were able to fix it using:

x = keras.applications.vgg16.preprocess_input(x)
x = keras.layers.Lambda(lambda x: keras.applications.vgg16.preprocess_input(x))(x)

And then using safe_mode=False for the ModelCheckpoint.

Any thoughts on how I can do this in R? I haven't been able to make any headway in getting it to work. I also saw another option, which is to define a get_config() method for custom objects, but I'm not sure what I would be defining. Let me know if I can provide any additional information.

Thank you!

t-kalinowski commented 7 months ago

Thanks for reporting.

The latest additions to the keras.applications.* module contain only a stub for preprocess_input(). For example, here is the implementation and docstring for keras.applications.convnext.preprocess_input:

@keras_export("keras.applications.convnext.preprocess_input")
def preprocess_input(x, data_format=None):
    """A placeholder method for backward compatibility.

    The preprocessing logic has been included in the convnext model
    implementation. Users are no longer required to call this method to
    normalize the input data. This method does nothing and only kept as a
    placeholder to align the API surface between old and new version of model.

    Args:
        x: A floating point `numpy.array` or a tensor.
        data_format: Optional data format of the image tensor/array. Defaults to
            None, in which case the global setting
            `keras.backend.image_data_format()` is used
            (unless you changed it, it defaults to `"channels_last"`).{mode}

    Returns:
        Unchanged `numpy.array` or tensor.
    """
    return x

https://github.com/keras-team/keras/blob/f7bc67e6c105c116a2ba7f5412137acf78174b1a/keras/applications/convnext.py#L730-L749

It seems to me that users writing new code today should not need to use a separate preprocess_input() function; the preprocessing steps should be managed as part of the model.

I'll take a closer look tomorrow and think about how to update the example.

jonbry commented 6 months ago

For some reason, I cannot get this example to work for me, even with the keras package. I think I must be doing something wrong. Here's the error message, Python Exception Message, sessionInfo(), and py_list_packages():

Error in py_call_impl(callable, call_args$unnamed, call_args$named) : 
  TypeError: Cannot serialize object Ellipsis of type <class 'ellipsis'>. To be serializable, a class must implement 

── Python Exception Message ─────────────────────────
AttributeError: module 'tensorflow.keras' has no attribute 'get_config'

── R Traceback ──────────────────────────────────────
     ▆
  1. └─.rs.rpc.get_help("get_config", "keras", 6L)
  2.   └─impl()
  3.     └─.rs.getHelpFunction(what, from)
  4.       ├─base::tryCatch(eval(call("$", container, name)), error = function(e) NULL)
  5.       │ └─base (local) tryCatchList(expr, classes, parentenv, handlers)
  6.       │   └─base (local) tryCatchOne(expr, names, parentenv, handlers[[1L]])
  7.       │     └─base (local) doTryCatch(return(expr), name, parentenv, handler)
  8.       ├─base::eval(call("$", container, name))
  9.       │ └─base::eval(call("$", container, name))
 10.       ├─`<python.builtin.module>`$get_config
 11.       └─reticulate:::`$.python.builtin.module`(...)
 12.         └─reticulate::py_get_attr(x, name, FALSE)
See `reticulate::py_last_error()$r_trace$full_call` for more details.

> sessionInfo()
R version 4.2.3 (2023-03-15)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 22.04.4 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so

locale:
 [1] LC_CTYPE=en_US.UTF-8      
 [2] LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8       
 [4] LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8   
 [6] LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8      
 [8] LC_NAME=C                 
 [9] LC_ADDRESS=C              
[10] LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8
[12] LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets 
[6] methods   base     

other attached packages:
[1] tfdatasets_2.9.0 fs_1.6.3        
[3] keras_2.15.0    

loaded via a namespace (and not attached):
 [1] zip_2.3.1              Rcpp_1.0.12           
 [3] compiler_4.2.3         pillar_1.9.0          
 [5] base64enc_0.1-3        tools_4.2.3           
 [7] zeallot_0.1.0          jsonlite_1.8.8        
 [9] lifecycle_1.0.4        tibble_3.2.1          
[11] gtable_0.3.4           lattice_0.20-45       
[13] pkgconfig_2.0.3        png_0.1-8             
[15] rlang_1.1.3            Matrix_1.5-3          
[17] cli_3.6.2              rstudioapi_0.16.0     
[19] dplyr_1.1.4            generics_0.1.3        
[21] vctrs_0.6.5            grid_4.2.3            
[23] tidyselect_1.2.1       reticulate_1.36.0     
[25] glue_1.7.0             R6_2.5.1              
[27] tfautograph_0.3.2      fansi_1.0.6           
[29] ggplot2_3.5.0          magrittr_2.0.3        
[31] whisker_0.4.1          backports_1.4.1       
[33] scales_1.3.0           listarrays_0.4.0      
[35] tfruns_1.5.2           colorspace_2.1-0      
[37] tensorflow_2.16.0.9000 utf8_1.2.4            
[39] munsell_0.5.1         
> reticulate::py_list_packages()
                        package     version
1                       absl-py       2.1.0
2                  array_record       0.5.1
3                    astunparse       1.6.3
4                    cachetools       5.3.3
5                       certifi    2024.2.2
6            charset-normalizer       3.3.2
7                         click       8.1.7
8                       dm-tree       0.1.8
9                         etils       1.7.0
10                  flatbuffers     24.3.25
11                       fsspec    2024.3.1
12                         gast       0.5.4
13                  google-auth      2.29.0
14         google-auth-oauthlib       1.2.0
15                 google-pasta       0.2.0
16                       grpcio      1.62.2
17                         h5py      3.11.0
18                         idna         3.7
19          importlib_resources       6.4.0
20                        keras      2.15.0
21                     libclang      18.1.1
22                     Markdown         3.6
23                   MarkupSafe       2.1.5
24                    ml-dtypes       0.3.2
25                        numpy      1.26.4
26           nvidia-cublas-cu12    12.2.5.6
27       nvidia-cuda-cupti-cu12    12.2.142
28        nvidia-cuda-nvcc-cu12    12.2.140
29       nvidia-cuda-nvrtc-cu12    12.2.140
30     nvidia-cuda-runtime-cu12    12.2.140
31            nvidia-cudnn-cu12    8.9.4.25
32            nvidia-cufft-cu12  11.0.8.103
33           nvidia-curand-cu12  10.3.3.141
34         nvidia-cusolver-cu12  11.5.2.141
35         nvidia-cusparse-cu12  12.1.2.141
36             nvidia-nccl-cu12      2.16.5
37        nvidia-nvjitlink-cu12    12.2.140
38                     oauthlib       3.2.2
39                   opt-einsum       3.3.0
40                    packaging        24.0
41                       pandas       2.2.2
42                       pillow      10.3.0
43                      promise         2.3
44                     protobuf      3.20.3
45                       psutil       5.9.8
46                       pyasn1       0.6.0
47               pyasn1_modules       0.4.0
48                        pydot       2.0.0
49                    pyparsing       3.1.2
50              python-dateutil 2.9.0.post0
51                         pytz      2024.1
52                     requests      2.31.0
53            requests-oauthlib       2.0.0
54                          rsa         4.9
55                        scipy      1.13.0
56                          six      1.16.0
57                  tensorboard      2.15.2
58      tensorboard-data-server       0.7.2
59                   tensorflow      2.15.1
60          tensorflow-datasets       4.9.4
61         tensorflow-estimator      2.15.0
62               tensorflow-hub      0.16.1
63 tensorflow-io-gcs-filesystem      0.36.0
64          tensorflow-metadata      1.15.0
65                    termcolor       2.4.0
66                     tf_keras      2.15.1
67                         toml      0.10.2
68                         tqdm      4.66.2
69            typing_extensions      4.11.0
70                       tzdata      2024.1
71                      urllib3       2.2.1
72                     Werkzeug       3.0.2
73                        wrapt      1.14.1
74                         zipp      3.18.1
                            requirement
1                        absl-py==2.1.0
2                   array_record==0.5.1
3                     astunparse==1.6.3
4                     cachetools==5.3.3
5                     certifi==2024.2.2
6             charset-normalizer==3.3.2
7                          click==8.1.7
8                        dm-tree==0.1.8
9                          etils==1.7.0
10                 flatbuffers==24.3.25
11                     fsspec==2024.3.1
12                          gast==0.5.4
13                  google-auth==2.29.0
14          google-auth-oauthlib==1.2.0
15                  google-pasta==0.2.0
16                       grpcio==1.62.2
17                         h5py==3.11.0
18                            idna==3.7
19           importlib_resources==6.4.0
20                        keras==2.15.0
21                     libclang==18.1.1
22                        Markdown==3.6
23                    MarkupSafe==2.1.5
24                     ml-dtypes==0.3.2
25                        numpy==1.26.4
26         nvidia-cublas-cu12==12.2.5.6
27     nvidia-cuda-cupti-cu12==12.2.142
28      nvidia-cuda-nvcc-cu12==12.2.140
29     nvidia-cuda-nvrtc-cu12==12.2.140
30   nvidia-cuda-runtime-cu12==12.2.140
31          nvidia-cudnn-cu12==8.9.4.25
32        nvidia-cufft-cu12==11.0.8.103
33       nvidia-curand-cu12==10.3.3.141
34     nvidia-cusolver-cu12==11.5.2.141
35     nvidia-cusparse-cu12==12.1.2.141
36             nvidia-nccl-cu12==2.16.5
37      nvidia-nvjitlink-cu12==12.2.140
38                      oauthlib==3.2.2
39                    opt-einsum==3.3.0
40                      packaging==24.0
41                        pandas==2.2.2
42                       pillow==10.3.0
43                         promise==2.3
44                     protobuf==3.20.3
45                        psutil==5.9.8
46                        pyasn1==0.6.0
47                pyasn1_modules==0.4.0
48                         pydot==2.0.0
49                     pyparsing==3.1.2
50         python-dateutil==2.9.0.post0
51                         pytz==2024.1
52                     requests==2.31.0
53             requests-oauthlib==2.0.0
54                             rsa==4.9
55                        scipy==1.13.0
56                          six==1.16.0
57                  tensorboard==2.15.2
58       tensorboard-data-server==0.7.2
59                   tensorflow==2.15.1
60           tensorflow-datasets==4.9.4
61         tensorflow-estimator==2.15.0
62               tensorflow-hub==0.16.1
63 tensorflow-io-gcs-filesystem==0.36.0
64          tensorflow-metadata==1.15.0
65                     termcolor==2.4.0
66                     tf_keras==2.15.1
67                         toml==0.10.2
68                         tqdm==4.66.2
69            typing_extensions==4.11.0
70                       tzdata==2024.1
71                       urllib3==2.2.1
72                      Werkzeug==3.0.2
73                        wrapt==1.14.1
74                         zipp==3.18.1
> 

Let me know if there is any additional information that I can provide to help troubleshoot the issue.

Thank you!

jonbry commented 6 months ago

I was able to get this example to work with the keras package (2.15) by changing the checkpoint file type to .h5. Since .h5 is a legacy format, I'll keep looking for a way to make it work with .keras.

jonbry commented 6 months ago

It appears that keras v3.3.3 fixes this issue as well. I think people who want to stay on the keras package will need to switch the checkpoint file format to .h5.
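
As a rough summary of the workaround above, here is a minimal sketch that picks a checkpoint file extension from the installed Keras version. It is written in Python to match the Keras snippets quoted earlier; the same choice applies to the filepath passed to the checkpoint callback in R. The helper names `parse_version` and `checkpoint_path` are hypothetical. The 3.3.3 cutoff is an assumption taken only from the version reported working in this thread; releases between 3.0 and 3.3.2 were not tested here.

```python
import re

# Assumption: 3.3.3 is the cutoff only because it is the version reported
# working in this thread; it may not be where the fix first landed.
FIRST_VERSION_REPORTED_WORKING = (3, 3, 3)


def parse_version(version):
    """Return the first three numeric components of a version string."""
    parts = [int(p) for p in re.findall(r"\d+", version)[:3]]
    # Pad short versions such as "3.4" so tuple comparison behaves sensibly.
    return tuple(parts + [0] * (3 - len(parts)))


def checkpoint_path(stem, keras_version):
    """Use the legacy .h5 format for older Keras, .keras otherwise."""
    if parse_version(keras_version) >= FIRST_VERSION_REPORTED_WORKING:
        return stem + ".keras"
    return stem + ".h5"


print(checkpoint_path("checkpoint", "2.15.0"))  # checkpoint.h5
print(checkpoint_path("checkpoint", "3.3.3"))   # checkpoint.keras
```

The returned path would then be passed as the filepath of the model checkpoint callback.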