rstudio / keras3

R Interface to Keras
https://keras3.posit.co/

How to use custom keras models #577

Closed. philippmuench closed this issue 6 years ago.

philippmuench commented 6 years ago

I would like to replace layer_lstm() with a multiplicative LSTM implementation (e.g. from here). I'm not sure which files I have to change. It looks like R/layers-recurrent.R needs to be changed and multiplicative_lstm.py has to be placed somewhere in the keras folder, but I have trouble finding the right locations (on Ubuntu, using virtualenv). Could someone point me to the file paths I should change? Or is there another way to use such custom modules, e.g. via reticulate? Thanks!

skeydan commented 6 years ago

Hi, we added a gist here

https://gist.github.com/jjallaire/92740db8588ce5e62bc8863e487c2134

that shows a way to do it (without making any changes to keras).
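
The gist is not reproduced in this thread. As a rough sketch of the general idea, and not necessarily what the gist does, a Python layer class can be wrapped from R using reticulate::import_from_path() and keras::create_layer(). The module name multiplicative_lstm and class name MultiplicativeLSTM are taken from the file name and error message quoted in this thread; the module_dir argument and the file location are assumptions made for illustration.

```r
# Sketch only: wrap a Python layer class so it can be used like the
# built-in layer_*() functions, without modifying the keras package.
# Assumptions: multiplicative_lstm.py lives in `module_dir` (hypothetical
# argument) and defines a class named MultiplicativeLSTM.
layer_multiplicative_lstm <- function(object, units, ..., module_dir = ".") {
  # Import the Python module from a local directory via reticulate.
  mlstm <- reticulate::import_from_path("multiplicative_lstm", path = module_dir)

  # create_layer() instantiates the class with `args` and composes it with
  # `object` (a sequential model to add to, or a tensor/layer to call on).
  keras::create_layer(
    mlstm$MultiplicativeLSTM,
    object,
    list(units = as.integer(units), ...)
  )
}

# Hypothetical usage (requires keras and the Python file to be available):
# model <- keras::keras_model_sequential()
# layer_multiplicative_lstm(model, units = 40, input_shape = c(80L, 6L),
#                           return_sequences = TRUE)
```

Importing the module inside the function keeps the definition free of side effects. The Python file probably also needs to import from the same Keras implementation (standalone keras or tensorflow.keras) that the R package is configured to use, though that is an assumption rather than something confirmed here.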

Let us know how this works for you.

philippmuench commented 4 years ago

> Let us know how this works for you.

@skeydan It works, thanks! However, I am not sure how to add this layer to a sequential model. Could you provide a short example?

What I have tried:

model1 <- keras::keras_model_sequential()
model1 %>% layer_multiplicative_lstm(40, input_shape = c(80, 6), return_sequences = TRUE)

which leads to

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  TypeError: The added layer must be an instance of class Layer. Found: <multiplicative_lstm.MultiplicativeLSTM object at 0x18ed50450> 

and

keras::use_implementation("tensorflow")
model2 <- keras::keras_model_sequential()
model2 %>% layer_multiplicative_lstm(40, input_shape = c(80, 6), return_sequences = TRUE)

which leads to a different error:

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  ValueError: Layer multiplicative_lstm_8 was called with an input that isn't a symbolic tensor. Received type: <class 'tensorflow.python.keras.engine.sequential.Sequential'>. Full input: [<tensorflow.python.keras.engine.sequential.Sequential object at 0x18ed50490>]. All inputs to the layer should be tensors. 

However, I can reproduce the example from here:

model <- keras_model_sequential()
model %>% layer_dense(units = 256, input_shape = c(784)) 
model %>% layer_antirectifier()

> tensorflow::tf$`__version__`
[1] "2.0.0"

> sessionInfo()
R version 3.6.1 (2019-07-05)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS Catalina 10.15.1

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib

Random number generation:
 RNG:     Mersenne-Twister 
 Normal:  Inversion 
 Sample:  Rounding 

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] deepG_0.2.0

loaded via a namespace (and not attached):
 [1] pkgload_1.0.2        bit64_0.9-7          jsonlite_1.6         hdf5r_1.3.0         
 [5] assertthat_0.2.1     stats4_3.6.1         remotes_2.1.0        slam_0.1-46         
 [9] sessioninfo_1.1.1    pillar_1.4.2         backports_1.1.5      lattice_0.20-38     
[13] glue_1.3.1           reticulate_1.13      digest_0.6.22        XVector_0.26.0      
[17] colorspace_1.4-1     Matrix_1.2-17        plyr_1.8.4           tm_0.7-7            
[21] pkgconfig_2.0.3      devtools_2.2.1       zlibbioc_1.32.0      purrr_0.3.3         
[25] scales_1.0.0         processx_3.4.1       whisker_0.4          Rtsne_0.15          
[29] tibble_2.1.3         generics_0.0.2       IRanges_2.20.0       ggplot2_3.2.1       
[33] usethis_1.5.1        ellipsis_0.3.0       withr_2.1.2          keras_2.2.5.0       
[37] lazyeval_0.2.2       BiocGenerics_0.32.0  NLP_0.2-0            cli_1.1.0           
[41] magrittr_1.5         crayon_1.3.4         memoise_1.1.0        ps_1.3.0            
[45] tokenizers_0.2.1     fs_1.3.1             SnowballC_0.6.0      xml2_1.2.2          
[49] pkgbuild_1.0.6       tools_3.6.1          prettyunits_1.0.2    hms_0.5.2           
[53] ArgumentCheck_0.10.2 stringr_1.4.0        S4Vectors_0.24.0     munsell_0.5.0       
[57] callr_3.3.2          Biostrings_2.54.0    compiler_3.6.1       rlang_0.4.1         
[61] grid_3.6.1           rstudioapi_0.10      base64enc_0.1-3      testthat_2.3.0      
[65] gtable_0.3.0         abind_1.4-5          roxygen2_6.1.1       R6_2.4.0            
[69] tfruns_1.4           dplyr_0.8.3          tensorflow_2.0.0     bit_1.1-14          
[73] zeallot_0.1.0        commonmark_1.7       rprojroot_1.3-2      readr_1.3.1         
[77] desc_1.2.0           stringi_1.4.3        parallel_3.6.1       Rcpp_1.0.2          
[81] vctrs_0.2.0          tidyselect_0.2.5     xfun_0.10           