rstudio / keras3

R Interface to Keras
https://keras3.posit.co/

Cannot pass a list of tensors into fit() (as the x parameter). #1234

Closed. maspotts closed this issue 3 years ago.

maspotts commented 3 years ago

Hi: I'm trying to use keras from R for the first time. I've got it installed and working (after a lot of trial and error with reticulate!). I'm trying to use the pre-trained yamnet model (https://tfhub.dev/google/yamnet/1) to extract useful features from my audio files (i.e. use its non-final layers as a feature generator). I then want to use the resulting tensors as inputs to the simple 2-layer classifier from https://blogs.rstudio.com/ai/posts/2018-01-11-keras-customer-churn, so that I can predict my target variable. Everything seems to work: I can run my WAV files through yamnet and generate the tensor features, and I can specify and compile the model. What I cannot figure out is how to pass the list of tensors as input to fit().

To be specific: I start with wavs, a list of 2030 WAV files (converted to 16 kHz mono in the [-1, 1] range, as yamnet requires), and a corresponding 2030-element vector of scalars, y. I first map the WAV files to tensors via yamnet's embedding. I have to take the 2nd element of each returned list, which is the embedding (the first and third elements are the scores and the spectrogram):

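    library(keras)
    library(tfhub)    # layer_hub()
    library(pbapply)  # pblapply(): lapply() with a progress bar

    yamnet_layer <- layer_hub(handle = "https://tfhub.dev/google/yamnet/1")
    # Take the left (mono) channel of each Wave object, run it through yamnet,
    # and keep the 2nd output of each call (the embeddings)
    inputs <- wavs %>% lapply(function(x) x@left) %>% pblapply(yamnet_layer) %>% lapply(function(x) x[[2]])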

So now inputs is a list of 2030 embeddings (tensors), e.g.:

> head(inputs, 2)
[[1]]
tf.Tensor(
[[0.         0.5778516  0.         ... 0.         0.         0.        ]
 [0.         0.         0.04971324 ... 0.         0.         0.        ]
 [0.         0.09253532 0.         ... 0.         0.         0.        ]], shape=(3, 1024), dtype=float32)

[[2]]
tf.Tensor(
[[0.         0.46789336 0.         ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]
 [0.         1.0887319  0.57517177 ... 0.         0.         0.        ]], shape=(3, 1024), dtype=float32)

So far, so good. The tensors all have shape (N, 1024), where N = 1, 2, 3, 4, etc., depending on the length of the WAV file. Now (following https://blogs.rstudio.com/ai/posts/2018-01-11-keras-customer-churn/) I define my model, taking care to specify input_shape as c(NULL, 1024) to indicate that the first dimension can vary (1, 2, 3, etc.) while the second dimension is constant (1024):

    model <- keras_model_sequential()

    ## First hidden layer
    model %>%
        layer_dense(
            units = 16,
            kernel_initializer = "uniform",
            activation = "relu",
            input_shape =  c(NULL, 1024)) %>%
        ## Dropout to prevent overfitting
        layer_dropout(rate = 0.1) %>%
        ## Second hidden layer
        layer_dense(
            units = 16,
            kernel_initializer = "uniform",
            activation = "relu") %>%
        ## Dropout to prevent overfitting
        layer_dropout(rate = 0.1) %>%
        ## Output layer
        layer_dense(
            units = 1,
            kernel_initializer = "uniform",
            activation = "sigmoid") %>%
        ## Compile ANN
        compile(
            optimizer = 'adam',
            loss = 'binary_crossentropy',
            metrics = c('accuracy')
        )
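An R detail worth knowing here (not discussed in the thread): c() silently drops NULL, so c(NULL, 1024) evaluates to just 1024, and the model above is in fact declared with input_shape = 1024, i.e. one 1024-feature row per sample. To keep a placeholder for an unknown dimension, a list() has to be used instead, e.g. list(NULL, 1024), which keeps the NULL as its own element (reticulate converts an R NULL to Python None). The base-R check below shows the difference without needing keras:

```r
# c() drops NULL elements, so the intended "variable" first dimension disappears
shape_c <- c(NULL, 1024)
print(shape_c)                  # [1] 1024
print(length(shape_c))          # [1] 1

# list() keeps NULL as a separate element, so both dimensions survive
shape_list <- list(NULL, 1024)
print(length(shape_list))       # [1] 2
print(is.null(shape_list[[1]])) # [1] TRUE
```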

Still no problems. But now I want to fit the model to my response variable y. At this point inputs is a 2030-element list of tensors, and y is a 2030-element numeric vector:

> head(y)
[1] 0 0 0 0 0 0

Now I want to train the classifier, but this is where I get stuck: my list of tensors (inputs) does not seem to be acceptable as an x value:

    history <- fit(
        object = model, 
        x = inputs,
        y = y,
        batch_size = 50, 
        epochs = 35,
        validation_split = 0.30
    )

I get this error:

Error in py_call_impl(callable, dots$args, dots$keywords) :
  ValueError: Data cardinality is ambiguous:
  x sizes: 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 2, 1, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 2, 1, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2
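The x sizes in that message are the leading (first) dimensions of the individual list elements. When x is an R list, Keras interprets it as a list of separate model inputs, one array per input, and checks that each has as many rows as y has elements. A base-R sketch with simulated matrices (fake_inputs and fake_y are hypothetical stand-ins for the yamnet embeddings and the labels) shows how to inspect this before calling fit(). For the real tensors, each one would first have to be converted to an R array, e.g. with as.array(); that is an assumption about the tensorflow package's conversion methods, not something tested in this thread.

```r
set.seed(1)
# Hypothetical stand-in for `inputs`: matrices with 1-3 rows (frames) and 1024 columns
n_samples <- 10
frames <- sample(1:3, n_samples, replace = TRUE)
fake_inputs <- lapply(frames, function(n) matrix(runif(n * 1024), nrow = n, ncol = 1024))
fake_y <- rbinom(n_samples, 1, 0.5)

# Leading dimension of each list element: this is what Keras reports as "x sizes"
x_sizes <- vapply(fake_inputs, nrow, integer(1))
print(x_sizes)

# fit() needs every input to have length(y) rows; none of these has 10 rows
all(x_sizes == length(fake_y))  # FALSE
```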

and I can't for the life of me figure out how to reformat inputs so that it will be accepted. I tried as.matrix(inputs) and as.array(inputs) but got:

Error in py_call_impl(callable, dots$args, dots$keywords) :
  Matrix type cannot be converted to python (only integer, numeric, complex, logical, and character matrixes can be converted

I tried unlist(inputs) but oddly that seems to be a no-op:

> unlist(inputs)
[[1]]
tf.Tensor(
[[0.         0.5778516  0.         ... 0.         0.         0.        ]
 [0.         0.         0.04971324 ... 0.         0.         0.        ]
 [0.         0.09253532 0.         ... 0.         0.         0.        ]], shape=(3, 1024), dtype=float32)

[[2]]
tf.Tensor(
[[0.         0.46789336 0.         ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]
 [0.         1.0887319  0.57517177 ... 0.         0.         0.        ]], shape=(3, 1024), dtype=float32)

(the output is still a list). I see there are tf$Tensor and tf$TensorArray objects, and I've had a poke around, but I can't get anywhere with those either, e.g.:

> ta <- tf$TensorArray(tf$float32, size = 2030L)
> ta[[1]] <- inputs[[1]]
Error in py_set_attr_impl(x, name, value) :
  Expecting a single string value: [type=double; extent=1].
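Why unlist() appeared to be a no-op: reticulate represents Python objects such as tf.Tensor as R environments, and unlist() only flattens atomic vectors and nested lists. A list whose elements are non-vector objects is not coerced, so it comes back as a list. In this base-R sketch, plain environments stand in for the tensors (an assumption made so the example runs without tensorflow):

```r
# Stand-ins for reticulate objects, which are R environments under the hood
fake_tensors <- list(new.env(), new.env())

flat <- unlist(fake_tensors)
is.list(flat)   # TRUE: non-vector elements are not coerced, so the result stays a list
length(flat)    # 2

# unlist() only flattens when the elements are (lists of) atomic vectors
unlist(list(1:2, 3:4))  # 1 2 3 4
```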

I also tried tf$ragged$stack(inputs) but got:

Error in py_call_impl(callable, dots$args, dots$keywords) :
  ValueError: `validation_split` is only supported for Tensors or NumPy arrays, found following types in the input: [<class 'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor'>]

So if anyone can tell me how to restructure inputs so that I can pass it into fit(), I will be hugely grateful! This is my first opportunity to solve a real problem with tensorflow/keras in R, and I can't wait to get past this first hurdle. I apologise if this turns out to be user error, but I couldn't find a more relevant forum for this question.

Many thanks, in advance,

Mike

t-kalinowski commented 3 years ago

Hi, thanks for filing.

The issue you’re encountering is due to the variable (undefined) dimension in your training data. Supplying an undefined dimension only postpones providing the necessary information: eventually, keras needs to materialize an actual array of numbers with an actual size in order to do any computation. Most often, a user-supplied “None/NULL” value for one of the dimensions carries an implicit promise: the size of that dimension will have no impact on the size of the layer weights.

For example, with a simple dense layer, calling it with different batch sizes is fine, because it has no impact on the layer weight dimensions.

library(keras)
layer <- layer_dense(units = 16, use_bias = FALSE)
layer(k_random_uniform(c(1, 3)))$shape
## (1, 16)
layer(k_random_uniform(c(11, 3)))$shape
## (11, 16)
layer(k_random_uniform(c(111, 3)))$shape
## (111, 16)
layer$weights[[1]]$shape
## (3, 16)
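The same point in plain linear algebra, as a base-R sketch (W is a hypothetical stand-in for the dense layer's (3, 16) kernel, with no bias, as in the example above): the batch dimension is the number of rows of the input, and it never enters the shape of W.

```r
set.seed(42)
W <- matrix(runif(3 * 16), nrow = 3, ncol = 16)  # kernel: (input features, units)

for (batch in c(1, 11, 111)) {
  x <- matrix(runif(batch * 3), nrow = batch, ncol = 3)  # input: (batch, 3)
  out <- x %*% W                                         # output: (batch, 16)
  print(dim(out))
}

dim(W)  # still 3 16: the batch size never changed the weights
```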

However, if you try to call it with different final dimensions, you get an error:

layer <- layer_dense(units = 16, use_bias = FALSE)
layer(k_random_uniform(c(1, 3)))$shape
## (1, 16)
layer(k_random_uniform(c(1, 33)))$shape |> try()
## Error in py_call_impl(callable, dots$args, dots$keywords) : 
##   ValueError: Input 0 of layer dense_1 is incompatible with the layer: expected axis -1 of input shape to have value 3 but received input with shape (1, 33)
## 
## Detailed traceback:
##   File "/home/tomasz/.local/share/r-miniconda/envs/r-reticulate/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1013, in __call__
##     input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
##   File "/home/tomasz/.local/share/r-miniconda/envs/r-reticulate/lib/python3.7/site-packages/tensorflow/python/keras/engine/input_spec.py", line 255, in assert_input_compatibility
##     ' but received input with shape ' + display_shape(x.shape))

This is because the matrix multiplication you’re asking keras to perform is impossible. The first call to the layer caused it to build its weights with shape (3, 16), so that they can be multiplied with the input of shape (1, 3). But the second input is incompatible with those weights: a matrix of shape (1, 33) cannot be multiplied with a matrix of shape (3, 16). To do the multiplication, Keras would need to discard the layer weights and re-initialize a new set of weights with a different size.
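The failing case can be reproduced with base-R matrix algebra as well (a sketch; W and W33 are hypothetical stand-ins for the old and the rebuilt kernels). A (1, 33) input cannot be multiplied by a (3, 16) kernel, while a freshly created (33, 16) kernel, which is what rebuilding the layer for 33 input features amounts to, makes the product valid:

```r
set.seed(42)
W   <- matrix(runif(3 * 16), nrow = 3, ncol = 16)  # kernel built for 3 input features
x33 <- matrix(runif(33), nrow = 1, ncol = 33)      # input with 33 features

# %*% requires ncol(x33) == nrow(W); 33 != 3, so R refuses, just as Keras does
result <- tryCatch(x33 %*% W, error = function(e) conditionMessage(e))
print(result)

# A new kernel sized for 33 features (fresh weights, not the old ones) works
W33 <- matrix(runif(33 * 16), nrow = 33, ncol = 16)
dim(x33 %*% W33)  # 1 16
```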

layer <- layer_dense(units = 16, use_bias = FALSE)
layer(k_random_uniform(c(1, 3)))$shape
## (1, 16)
layer(k_random_uniform(c(1, 33)))$shape |> try() # error, wrong shape
## Error in py_call_impl(callable, dots$args, dots$keywords) : 
##   ValueError: Input 0 of layer dense_2 is incompatible with the layer: expected axis -1 of input shape to have value 3 but received input with shape (1, 33)
## 
## Detailed traceback:
##   File "/home/tomasz/.local/share/r-miniconda/envs/r-reticulate/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1013, in __call__
##     input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
##   File "/home/tomasz/.local/share/r-miniconda/envs/r-reticulate/lib/python3.7/site-packages/tensorflow/python/keras/engine/input_spec.py", line 255, in assert_input_compatibility
##     ' but received input with shape ' + display_shape(x.shape))
layer$weights[[1]]$shape # current layer weights shape
## (3, 16)
layer$build(shape(NULL, 33)) # rebuild for a new input shape
layer$weights[[1]]$shape # the new layer weights shape
## (33, 16)
layer(k_random_uniform(c(1, 33)))$shape # now it works
## (1, 16)

In your example, this is effectively what you’re attempting to do. In the fit() call you’re providing a variable-shape input that (implicitly) asks keras to train 3 different models of 3 different sizes, corresponding to how many frames yamnet produced from each WAV file (1, 2, or 3). Keras is rightly complaining that this is impossible.

To train a model over the yamnet feature outputs, you’re going to have to come up with an architecture that can accommodate that variable dimension, or do some additional pre-processing on your data in order to fix that variable dimension to a single size. Some common strategies here include: padding to a common length, pooling, downsampling, or converting to a recurrent architecture.
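A minimal base-R sketch of the padding strategy, assuming the embeddings have already been converted to ordinary R matrices and using a hypothetical helper, pad_frames(). Every sample is zero-padded to the same number of frames, and the list is stacked into one (samples, frames, 1024) array with a single fixed shape. A model consuming it would then need an input_shape of c(max_frames, 1024), followed by, for example, layer_flatten() before the dense layers; that model change is not shown here.

```r
# Hypothetical helper: zero-pad a list of (frames x features) matrices to the same
# number of frames and stack them into a (samples, frames, features) array.
pad_frames <- function(mats, max_frames = max(vapply(mats, nrow, integer(1)))) {
  n_features <- ncol(mats[[1]])
  out <- array(0, dim = c(length(mats), max_frames, n_features))
  for (i in seq_along(mats)) {
    n <- min(nrow(mats[[i]]), max_frames)  # truncate samples longer than max_frames
    out[i, seq_len(n), ] <- mats[[i]][seq_len(n), , drop = FALSE]
  }
  out
}

# Usage with simulated embeddings standing in for the yamnet outputs
set.seed(1)
mats <- lapply(c(1, 3, 2), function(n) matrix(runif(n * 1024), nrow = n, ncol = 1024))
x <- pad_frames(mats)
dim(x)  # 3 3 1024
```

Padding keeps every frame's information; the trade-off is that the zero rows look like real (silent) frames to the model unless they are masked.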

maspotts commented 3 years ago

Wow, thank you for a very clear and informative explanation: that totally makes sense. In fact I did work around the problem by averaging the N = 2, 3, ... frames down to N = 1 for each sample, but I now understand that this was an appropriate and necessary step given the architecture. Thanks again: I really appreciate the resolution!

Mike
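The averaging workaround described in the last comment amounts to mean pooling over the frame axis. Here is a base-R sketch of it, again assuming the embeddings are available as R matrices (simulated here) and using a hypothetical helper name, mean_pool(). Each (N, 1024) matrix collapses to one 1024-vector via colMeans(), and the vectors stack into a (samples, 1024) matrix. That shape matches an input_shape of 1024 and can be passed as x to fit().

```r
# Hypothetical helper: average each sample's frames into a single feature vector
mean_pool <- function(mats) {
  # vapply() returns a (features x samples) matrix; t() makes it (samples x features)
  t(vapply(mats, colMeans, numeric(ncol(mats[[1]]))))
}

set.seed(1)
mats <- lapply(c(1, 3, 2), function(n) matrix(runif(n * 1024), nrow = n, ncol = 1024))
x_pooled <- mean_pool(mats)
dim(x_pooled)  # 3 1024
```

Unlike padding, pooling discards the frame order, which is acceptable when the classifier only needs a summary of the whole clip.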