tensorflow / tfjs

A WebGL accelerated JavaScript library for training and deploying ML models.
https://js.tensorflow.org
Apache License 2.0

Prediction from tf.data.Dataset #2679

Closed loretoparisi closed 4 years ago

loretoparisi commented 4 years ago

TensorFlow.js version

latest version

Node version

$ node --version
v12.13.1

Describe the problem or feature request

In my Python code I have a tf.data.Dataset in which a file list is mapped through a tf.py_function:

dataset = tf.data.Dataset.from_tensor_slices({
    'audio_id': list(filenames),
    'start': list(starts),
    'end': list(ends)
})
dataset = dataset.map(
    lambda sample: dict(
        sample,
        **audio_adapter.load_tf_waveform(
            sample['audio_id'],
            session=session,
            sample_rate=sample_rate,
            offset=sample['start'],
            duration=sample['end'] - sample['start'])),
    num_parallel_calls=2)

When I load the model into tfjs with the new tf.node.loadSavedModel I get the model with the following signature:

TFSavedModel {
  sessionId: 0,
  jsid: 0,
  inputNodeNames: {
    audio_id: 'Placeholder_1:0',
    mix_spectrogram: 'strided_slice_3:0',
    mix_stft: 'transpose_1:0',
    waveform: 'Placeholder:0'
  },
  outputNodeNames: {
    accompaniment: 'strided_slice_23:0',
    audio_id: 'Placeholder_1:0',
    vocals: 'strided_slice_13:0'
  },
  backend: NodeJSKernelBackend {
    binding: {},
    isGPUPackage: false,
    isUsingGpuDevice: false,
    tensorMap: DataStorage {
      backend: [Circular],
      dataMover: [Engine],
      data: [WeakMap],
      dataIdsCount: 0
    }
  },
  disposed: false
}

In Python, predict takes the dataset as input:

fn = lambda: get_dataset(
    audio_adapter,
    filenames_and_crops,
    sample_rate,
    n_channels,
    session)
with session.as_default():
    with session.graph.as_default():
        prediction = estimator.predict(fn, yield_single_examples=False)

How can I load this dataset format into tfjs model.predict?

Code to reproduce the bug / link to feature request

-

tafsiri commented 4 years ago

Currently model.predict does not take a dataset; it only takes tensors. You could, however, use functions like map or a custom loop to call the predict function of your loaded model with a realized tensor. You can use the other dataset methods to set up a pipeline that streams the data from disk. The generator function is probably the most flexible way to create a dataset from file streams (we don't have a built-in utility to stream data from files in Node).

cc @kangyizhang in case I missed anything or got anything wrong.

tafsiri commented 4 years ago

To clarify my message a bit: the function passed to tf.data.generator would need to do the work that **audio_adapter.load_tf_waveform does, after opening the files that you want (you can chain generators if you want).
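A minimal sketch of the generator-plus-custom-loop idea described above, written in plain JavaScript so it runs without tfjs installed. Everything here is an assumption for illustration: loadWaveform is a hypothetical stand-in for what audio_adapter.load_tf_waveform does in Python (it only fabricates silence), and stubModel stands in for the object returned by tf.node.loadSavedModel. With tfjs, a function such as () => samples(files, sampleRate) could be handed to tf.data.generator, and the yielded arrays would first have to be turned into tensors keyed by the model's input names.

```javascript
// Hypothetical stand-in for the work audio_adapter.load_tf_waveform does in
// Python: return the samples for one crop. Here it only fabricates silence.
function loadWaveform(audioId, start, end, sampleRate) {
  const nSamples = Math.round((end - start) * sampleRate);
  return new Float32Array(nSamples);
}

// Generator that yields one realized sample per file crop.
function* samples(files, sampleRate) {
  for (const { audioId, start, end } of files) {
    yield {
      audio_id: audioId,
      waveform: loadWaveform(audioId, start, end, sampleRate),
    };
  }
}

// Custom loop: call predict once per realized sample instead of passing
// a dataset to predict.
function predictAll(model, files, sampleRate) {
  const results = [];
  for (const sample of samples(files, sampleRate)) {
    results.push(model.predict(sample));
  }
  return results;
}

// Stub model (assumption) so the sketch runs on its own; it just reports
// how many samples it received for each audio_id.
const stubModel = {
  predict: ({ audio_id, waveform }) => ({ audio_id, nSamples: waveform.length }),
};

const files = [
  { audioId: 'a.wav', start: 0, end: 1 },
  { audioId: 'b.wav', start: 2, end: 2.5 },
];
const out = predictAll(stubModel, files, 44100);
console.log(out);
```

The generator keeps only one crop in memory at a time, which is the property the Python dataset.map pipeline provides; decoding real audio inside loadWaveform is left open here.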

kangyizhang commented 4 years ago

Like Yannick said, the current model.predict function in tfjs does not take a dataset as input. One option is to modify the model in Python so that its input is a tensor, then export the model again.

rthadur commented 4 years ago

Closing this issue; feel free to reopen if you need more info.

loretoparisi commented 4 years ago

@rthadur @kangyizhang @tafsiri thank you very much for your help. I get the big picture, but it is still pretty hard to figure out how to actually implement it. An example covering tf.data.Dataset and audio files would be very helpful, thanks a lot.

danwexler commented 3 years ago

Is there a feature request to add tf.data.Dataset as a supported input type for model.predict? In the meantime, I guess we use tf.stack and handle pre-loading ourselves? Is it worth looking at the Keras model.predict as an example of how to implement tf.data.Dataset support in model.predict?