tensorflow / tensorflow

An Open Source Machine Learning Framework for Everyone
https://tensorflow.org
Apache License 2.0

Need tf.signal.rfft op in TFLite #27303

Closed jpangburn closed 2 years ago

jpangburn commented 5 years ago

System information

Provide the text output from tflite_convert

If I pass the SELECT_TF_OPS option, then I get:

Some of the operators in the model are not supported by the standard TensorFlow Lite runtime and are not recognized by TensorFlow. If you have a custom implementation for them you can disable this error with --allow_custom_ops, or by setting allow_custom_ops=True when calling tf.lite.TFLiteConverter(). Here is a list of builtin operators you are using: ADD, CAST, CONCATENATION, DIV, EXPAND_DIMS, FLOOR_DIV, FULLY_CONNECTED, GATHER, LOG, MAXIMUM, MINIMUM, MUL, PACK, PAD, RANGE, RESHAPE, SHAPE, SPLIT, SPLIT_V, STRIDED_SLICE, SUB, TRANSPOSE. Here is a list of operators for which you will need custom implementations: RFFT.

If I don't pass the SELECT_TF_OPS option then I get:

Some of the operators in the model are not supported by the standard TensorFlow Lite runtime. If those are native TensorFlow operators, you might be able to use the extended runtime by passing --enable_select_tf_ops, or by setting target_ops=TFLITE_BUILTINS,SELECT_TF_OPS when calling tf.lite.TFLiteConverter(). Otherwise, if you have a custom implementation for them you can disable this error with --allow_custom_ops, or by setting allow_custom_ops=True when calling tf.lite.TFLiteConverter(). Here is a list of builtin operators you are using: ADD, CAST, CONCATENATION, DIV, EXPAND_DIMS, FLOOR_DIV, FULLY_CONNECTED, GATHER, LOG, MAXIMUM, MINIMUM, MUL, PACK, PAD, RANGE, RESHAPE, SHAPE, SPLIT, SPLIT_V, STRIDED_SLICE, SUB, TRANSPOSE. Here is a list of operators for which you will need custom implementations: ComplexAbs, Cos, LinSpace, RFFT.

Also, please include a link to a GraphDef or the model if possible.

The code to create the model is pretty short since I hardcoded a bunch of parameters for now:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    # input sound data as a waveform
    waveform = tf.placeholder(tf.float32, [None])
    # Convert to mono PCM samples in the range [-1, 1].
    pcm = tf.math.scalar_mul(1/32768.0, waveform)

    # compute Short Time Fourier Transform
    stft = tf.signal.stft(pcm, frame_length=400, frame_step=160, fft_length=512)
    spectrogram = tf.abs(stft)

    # Warp the linear scale spectrograms into the mel-scale.
    num_spectrogram_bins = stft.shape[-1].value
    sample_rate = 16000.0  # assumed 16 kHz audio (the rate mentioned later in this thread)
    lower_edge_hertz, upper_edge_hertz, num_mel_bins = 125.0, 7500.0, 64
    linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
      num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
      upper_edge_hertz)
    mel_spectrogram = tf.tensordot(
      spectrogram, linear_to_mel_weight_matrix, 1)
    mel_spectrogram.set_shape(spectrogram.shape[:-1].concatenate(
      linear_to_mel_weight_matrix.shape[-1:]))

    # Compute a stabilized log to get log-magnitude mel-scale spectrograms.
    log_mel_spectrogram = tf.log(mel_spectrogram + 1e-6)
    # with the model loaded and input/output tensors defined, convert to tf.lite
    converter = tf.lite.TFLiteConverter.from_session(sess, [waveform], [log_mel_spectrogram])
    converter.target_ops = [tf.lite.OpsSet.SELECT_TF_OPS, tf.lite.OpsSet.TFLITE_BUILTINS]
    tflite_model = converter.convert()

Any other info / logs

Assuming the SELECT_TF_OPS option produces a model that will work on TFLite on iOS, then I guess all I need is RFFT. Thank you!

jpangburn commented 5 years ago

I checked the master branch and I see this commit: https://github.com/tensorflow/tensorflow/commit/c77e7e56de56c624116cf9eea340b4f96f032c85#diff-ed4b7d597384e8e4b1210b7558a16640

Looks like someone already added what I needed, it's just not released yet :-) Sorry, I should have dug into the source before filing this issue.

If anyone else runs into this, it'll probably be available in the release after 1.13.1, or build from master.

jpangburn commented 5 years ago

I tried this with the nightly build from March 31, 2019, and unfortunately this op doesn't seem to actually work. The converter lets it through because it's on the whitelist, but when I try to allocate tensors in TF Lite with a model that uses it, the allocation fails. Here's example code that requires numpy and tensorflow (the stft op calls the rfft op):

import tensorflow as tf
import numpy as np
target_path = "/tmp/log_mel_spectrogram.tflite"
with tf.Graph().as_default(), tf.Session() as sess:
    # input sound data as a waveform
    waveform = tf.placeholder(tf.float32, [None])
    # Convert to mono PCM samples in the range [-1, 1].
    pcm = tf.math.scalar_mul(1/32768.0, waveform)

    # compute Short Time Fourier Transform
    stft = tf.signal.stft(pcm, frame_length=400, frame_step=160, fft_length=512)

    # with the model loaded and input/output tensors defined, convert to tf.lite
    converter = tf.lite.TFLiteConverter.from_session(sess, [waveform], [stft])
    converter.target_ops = [tf.lite.OpsSet.SELECT_TF_OPS, tf.lite.OpsSet.TFLITE_BUILTINS]
    tflite_model = converter.convert()
    open(target_path, "wb").write(tflite_model)

# verify we can execute the converted model (should print out an array)

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=target_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test model on random input data.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)

It fails at the interpreter.allocate_tensors() line with the following:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-11-85ffc6e57901> in <module>
     21 # Load TFLite model and allocate tensors.
     22 interpreter = tf.lite.Interpreter(model_path=target_path)
---> 23 interpreter.allocate_tensors()
     24 
     25 # Get input and output tensors.

~/python_envs/tensorflow_master_monolithic/lib/python3.6/site-packages/tensorflow/lite/python/interpreter.py in allocate_tensors(self)
     93   def allocate_tensors(self):
     94     self._ensure_safe()
---> 95     return self._interpreter.AllocateTensors()
     96 
     97   def _safe_to_run(self):

~/python_envs/tensorflow_master_monolithic/lib/python3.6/site-packages/tensorflow/lite/python/interpreter_wrapper/tensorflow_wrap_interpreter_wrapper.py in AllocateTensors(self)
    104 
    105     def AllocateTensors(self):
--> 106         return _tensorflow_wrap_interpreter_wrapper.InterpreterWrapper_AllocateTensors(self)
    107 
    108     def Invoke(self):

RuntimeError: tensorflow/lite/kernels/split_v.cc:129 input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || input_type == kTfLiteInt16 was not true.Node number 1 (SPLIT_V) failed to prepare.

Changing the output tensor from stft to pcm (converter = tf.lite.TFLiteConverter.from_session(sess, [waveform], [pcm])) works fine and prints a result, so it seems that although the RFFT op is whitelisted, it doesn't actually work?

jdduke commented 5 years ago

Hi @jpangburn, I'm curious if the model works if you change

converter.target_ops = [tf.lite.OpsSet.SELECT_TF_OPS, tf.lite.OpsSet.TFLITE_BUILTINS]

to

converter.target_ops = [tf.lite.OpsSet.SELECT_TF_OPS]

My guess is that some of our builtin ops don't support the full suite of types (e.g., complex types) in all operators, which may be required when using rfft. Can you give that a try and let us know how it goes? Thanks.

jpangburn commented 5 years ago

Hi @jdduke , I tried that and I get

RuntimeError: Regular TensorFlow ops are not supported by this interpreter. Make sure you invoke the Flex delegate before inference.Node number 0 (Flex) failed to prepare.

I also read somewhere that TF lite doesn't support [None] as an input dimension so I tried fixing it to 16000 (one second worth of data at 16k sample rate) so the code looks like this:

import tensorflow as tf
import numpy as np
target_path = "/tmp/log_mel_spectrogram.tflite"
with tf.Graph().as_default(), tf.Session() as sess:
    # input sound data as a waveform
    waveform = tf.placeholder(tf.float32, [16000])
    # Convert to mono PCM samples in the range [-1, 1].
    pcm = tf.math.scalar_mul(1/32768.0, waveform)

    # compute Short Time Fourier Transform
    stft = tf.signal.stft(pcm, frame_length=400, frame_step=160, fft_length=512)

    # with the model loaded and input/output tensors defined, convert to tf.lite
    converter = tf.lite.TFLiteConverter.from_session(sess, [waveform], [stft])
    converter.target_ops = [tf.lite.OpsSet.SELECT_TF_OPS, tf.lite.OpsSet.TFLITE_BUILTINS]
    tflite_model = converter.convert()
    open(target_path, "wb").write(tflite_model)

# verify we can execute the converted model (should print out an array)

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=target_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test model on random input data.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)

But that gets the following error:

2019-04-08 09:26:39.664498: I tensorflow/lite/toco/import_tensorflow.cc:1336] Converting unsupported operation: RFFT
2019-04-08 09:26:39.675630: I tensorflow/lite/toco/graph_transformations/graph_transformations.cc:39] Before Removing unused ops: 8 operators, 20 arrays (0 quantized)
2019-04-08 09:26:39.675859: I tensorflow/lite/toco/graph_transformations/graph_transformations.cc:39] Before general graph transformations: 8 operators, 20 arrays (0 quantized)
2019-04-08 09:26:39.676099: I tensorflow/lite/toco/graph_transformations/graph_transformations.cc:39] After general graph transformations pass 1: 8 operators, 19 arrays (0 quantized)
2019-04-08 09:26:39.676148: F tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc:118] Check failed: dim_x == dim_y (80 vs. 400)Dimensions must match
Fatal Python error: Aborted

I don't know why it has a problem with the dimensions, as it works fine in regular TF. For example, if you add this line right after the stft = line, it prints out an array just fine: print(sess.run([stft], {waveform: np.array(np.random.random_sample((16000,)), dtype=np.float32)})).

jdduke commented 5 years ago

We're still working out some issues when running inference with the TF ops from Python (see https://www.tensorflow.org/lite/guide/ops_select#python_pip_package). We're hoping to have that resolved for the 1.14 release.

Have you tried running with a manual TFLite build (either with C++ or Java)? We'll be releasing pre-built .aar/cocoapods in the near future, which should make this easier.

jpangburn commented 5 years ago

I just tried it with an iOS build (from the tensorflow/lite/experimental/swift directory) and it gets the same error as the full code sample I provided earlier:

TensorFlow Lite Error: tensorflow/lite/kernels/split_v.cc:129 input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || input_type == kTfLiteInt16 was not true. TensorFlow Lite Error: Node number 1 (SPLIT_V) failed to prepare.

That's with the code sample that successfully creates the model but won't run it. In the Swift code I create a TFLite interpreter and pass it 16k worth of float32 test data. I don't know what the difference is between that build and the manual build you are referring to, are they the same or different?

jdduke commented 5 years ago

If you don't mind, feel free to forward me your converted .tflite model (and/or the source model), and I'd be happy to troubleshoot further.

jpangburn commented 5 years ago

Oh, sorry, I see now. That build you pointed me to has support for "select tensorflow ops" which the build I was using probably doesn't.

Regarding the tflite model, you can generate it yourself from the code I provided which just requires a python environment with numpy and tensorflow:

import tensorflow as tf
import numpy as np
target_path = "/tmp/log_mel_spectrogram.tflite"
with tf.Graph().as_default(), tf.Session() as sess:
    # input sound data as a waveform
    waveform = tf.placeholder(tf.float32, [None])
    # Convert to mono PCM samples in the range [-1, 1].
    pcm = tf.math.scalar_mul(1/32768.0, waveform)

    # compute Short Time Fourier Transform
    stft = tf.signal.stft(pcm, frame_length=400, frame_step=160, fft_length=512)

    #print(sess.run([stft], {waveform: np.array(np.random.random_sample((16000,)), dtype=np.float32)}))
    # with the model loaded and input/output tensors defined, convert to tf.lite
    converter = tf.lite.TFLiteConverter.from_session(sess, [waveform], [stft])
    converter.target_ops = [tf.lite.OpsSet.SELECT_TF_OPS, tf.lite.OpsSet.TFLITE_BUILTINS]
    tflite_model = converter.convert()
    open(target_path, "wb").write(tflite_model)

The commented out print line provides a sample way to invoke it (in regular TF). If you need to remove the [None] parameter then you could replace it with [16000] to let it process one second worth of data and that would be usable too. I've also provided the model zipped up if that's easier. log_mel_spectrogram.tflite.zip

dkashkin commented 4 years ago

@jpangburn can you please share the latest status on this? Have you been able to find a workaround?

jpangburn commented 4 years ago

@dkashkin I don't know what its current status in TF is; I wrote the RFFT out manually in code to solve it for my problem. I submitted this because the TFLite documentation asked people to let them know which operations are actually needed. As this is needed for audio signal processing to use the Google sound model Audioset, it seemed like a fair request :-)

Writing it out manually was not easy for me, and I imagine other people will want to use Audioset on a mobile device, so this is a useful thing to provide in TFLite. But again, I don't know its status.

jdduke commented 4 years ago

We've implemented rfft2d as a custom op which can be used optionally, we could probably do the same for rfft. I've filed an internal request to explore implementing this as a proper builtin op. Thanks for your patience.

dkashkin commented 4 years ago

Thanks @jpangburn! Can you please clarify: did you end up manually coding the entire preprocessing algorithm (generating spectrograms from raw audio), or is there any way to run the Fast Fourier Transform inside TensorFlow Lite? PS. I am trying to improve the performance of my app by offloading this heavy preprocessing to TFLite, in the hope that it will leverage the GPU whenever available... Appreciate your help :)

jpangburn commented 4 years ago

Hi @dkashkin yes, I manually coded it to generate spectrograms from raw audio mainly with a bunch of vDSP calls in Swift (e.g. vDSP_DFT_Execute to execute the FFT with a vectorized implementation). If you're on iOS and unfamiliar with vDSP look here https://developer.apple.com/documentation/accelerate/vdsp. Supposedly highly optimized. It ran fast enough for what I was trying to do.

If you're doing the AudioSet stuff and are trying to follow VGGish but get it working with TFLite, then as of when I submitted this, this was the best option for iOS IMHO. If you're writing for Android, then obviously you would need to use their equivalent. If @jdduke is able to get this into TFLite, that would make this WAAAY easier. I imagine we're not the only ones wanting to use AudioSet on TFLite :-) I'm glad you chimed in. It would also help us write more portable code if we could get that whole thing into TFLite; I'd like my stuff to run on Android as well, but coding it again in another language is not something I really want to do haha!

dkashkin commented 4 years ago

Thanks @jpangburn I agree with you - everybody who tries to do signal processing for mobile is wasting a lot of time because TFLite still cannot do FFT :( PS. If you are interested I'd be happy to share my Kotlin code that generates MFCC on Android. The only problem is - it's dog slow (3 seconds per spectrogram on low end phones which is not acceptable). I keep my fingers crossed that @jdduke finds a quick solution...

PCerles commented 4 years ago

+1, this would be great to have

jdduke commented 4 years ago

This is great discussion, thanks all for the feedback.

As there are a number of related ops to FFT (RFFT(2D/3D), IRFFT(2D/3D), FFT(2D/3D), ComplexAbs), it would be good to know precisely which op variants are required for your models so we can prioritize support accordingly. If you can link to specific source models, that would be extremely helpful. Thanks!

PCerles commented 4 years ago

ComplexAbs, RFFT (2D) are the impacted ops for me. The part of the model impacted is just this https://www.tensorflow.org/api_docs/python/tf/signal/mfccs_from_log_mel_spectrograms line-for-line.

rryan commented 4 years ago

Hi! I'm the author of tf.signal. Sorry that tf.signal.stft doesn't fully work yet.

In case it's helpful, you can replace the RFFT and ComplexAbs in your network with an RDFT matrix multiply. Here's an example:

https://github.com/tensorflow/magenta/blob/cf80d19fc0c2e935821f284ebb64a8885f793717/magenta/music/melspec_input.py#L64-L90

That file has some other tflite compatibility tricks that are no longer required I believe (e.g. tf.signal.frame should be supported natively now).
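A minimal NumPy sketch of the RDFT-by-matmul idea (written for illustration, not copied from the Magenta file linked above; `rdft_matrices` and `rdft_magnitude` are hypothetical names). It builds cosine and sine matrices once, so |rfft(x)| becomes two real matrix multiplies plus elementwise squares and a square root, with no complex dtype anywhere. In a TF graph the same steps would map to tf.matmul, tf.square and tf.sqrt; exactly which TFLite builtins they lower to depends on your TF version.

```python
import numpy as np


def rdft_matrices(fft_length):
    """Real/imaginary DFT matrices of shape [fft_length, fft_length // 2 + 1]."""
    n = np.arange(fft_length)[:, None]           # time index (rows)
    k = np.arange(fft_length // 2 + 1)[None, :]  # frequency bin (columns)
    angle = 2.0 * np.pi * n * k / fft_length
    # rfft(x)[k] = sum_n x[n] * (cos(angle) - 1j * sin(angle))
    return np.cos(angle).astype(np.float32), (-np.sin(angle)).astype(np.float32)


def rdft_magnitude(frames, fft_length):
    """|rfft(frames)| using only real matmuls and elementwise ops."""
    # Zero-pad each frame on the right up to fft_length, as rfft does.
    pad = fft_length - frames.shape[-1]
    frames = np.pad(frames, [(0, 0)] * (frames.ndim - 1) + [(0, pad)])
    real_m, imag_m = rdft_matrices(fft_length)
    real = frames @ real_m
    imag = frames @ imag_m
    return np.sqrt(real ** 2 + imag ** 2)


rng = np.random.default_rng(0)
frames = rng.standard_normal((98, 400)).astype(np.float32)  # e.g. 98 windowed frames
mag = rdft_magnitude(frames, fft_length=512)
ref = np.abs(np.fft.rfft(frames, n=512))  # complex FFT reference
print(mag.shape, np.max(np.abs(mag - ref)))
```

The trade-off is O(N^2) work per frame instead of O(N log N), which is usually acceptable for short frames like fft_length=512.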

rryan commented 4 years ago

Also, I believe tf.abs(tf.signal.stft(...)) works with SELECT_TF_OPS if you use the "new" converter (i.e. enable experimental_new_converter here).

dkashkin commented 4 years ago

@rryan big THANK YOU for tf.signal! It's awesome.

dkashkin commented 4 years ago

@rryan can you please give us some more details on why you think the experimental_new_converter should support tf.signal.stft? I just tried this option and although it generates a slightly different tflite file, the tflite inference still fails in this test:

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
with tf.Graph().as_default(), tf.Session() as sess:
    waveform = tf.placeholder(tf.float32, [None])
    result = tf.abs(tf.signal.stft(waveform, frame_length=400, frame_step=160, fft_length=512))
    converter = tf.lite.TFLiteConverter.from_session(sess, [waveform], [result])
    converter.target_spec.supported_ops = [tf.lite.OpsSet.SELECT_TF_OPS]
    converter.experimental_enable_mlir_converter = True
    tflite_model = converter.convert()
    open("/content/test.tflite", "wb").write(tflite_model)
interpreter = tf.lite.Interpreter(model_path="/content/test.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()

RuntimeError: Regular TensorFlow ops are not supported by this interpreter. Make sure you apply/link the Flex delegate before inference.Node number 0 (FlexConst) failed to prepare.

rryan commented 4 years ago

Quoting @dkashkin: "@rryan big THANK YOU for tf.signal! It's awesome."

Thanks :) Glad it's useful.

Quoting @dkashkin: "@rryan can you please give us some more details on why you think the experimental_new_converter should support tf.signal.stft?"

I think the binary you use to run the model would need the flex delegate linked in. Instructions here: https://www.tensorflow.org/lite/guide/ops_select#running_the_model

I don't know how to make this work for Python -- and the instructions here don't say how. Maybe @jdduke knows?

jpangburn commented 4 years ago

@jdduke Regarding what models, the one I was trying to get a TFLite model for is a Tensorflow model: https://github.com/tensorflow/models/tree/master/research/audioset/vggish. The tough one for me was a numpy call "np.fft.rfft" inside mel_features.py's stft_magnitude() method. From my original post, I also tried to use tf.signal.stft to replace this, but it was calling RFFT and that didn't work at the time in TFLite.

PCerles commented 4 years ago

Thanks a ton, I got this to work with tflite on tensorflow 1.15 on fixed size audio with the magenta code.

dkashkin commented 4 years ago

Thanks again @rryan! I just tested the TFLite-compatible graph generated by Magenta. It works on Android and seems to be roughly 5 times faster than my Kotlin-based STFT. The tflite model is quite large, though (4.6 MB for a 224x224 spectrogram).
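Part of the size cost of the RDFT workaround is that the DFT matrices are stored as constants in the .tflite file. A rough back-of-the-envelope helper, assuming full float32 real and imaginary matrices of shape [fft_length, fft_length // 2 + 1] (the Magenta code may trim or store them differently, and the mel weight matrix and the model's own weights add to the total); `rdft_constant_bytes` is a hypothetical name:

```python
def rdft_constant_bytes(fft_length, bytes_per_value=4):
    """Bytes for real + imaginary RDFT matrices of shape [fft_length, fft_length // 2 + 1]."""
    num_bins = fft_length // 2 + 1
    return 2 * fft_length * num_bins * bytes_per_value


for fft_length in (512, 1024, 2048):
    size = rdft_constant_bytes(fft_length)
    print(f"fft_length={fft_length}: {size} bytes ({size / 2**20:.2f} MiB)")
```

For fft_length=512 that is 1,052,672 bytes (about 1 MiB); the constant grows roughly with the square of fft_length.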

padoremu commented 4 years ago

@rryan Thanks! Using the magenta code helped, up to the point where I use tf.signal.mfccs_from_log_mel_spectrograms, which apparently needs RFFT / FlexRFFT as well. Any hint for a replacement to make it TFLite compatible without making use of TensorFlow ops (avoiding the use of flex delegates)?

rryan commented 4 years ago

@padoremu Yes, that's right: calculating MFCCs requires the DCT, which is implemented using an FFT :-/.

You would need to replace the DCT that it uses with something that uses an RDFT instead of RFFT. Here's a colab showing how to do it and checking for tf.signal compatibility: https://colab.research.google.com/drive/1C9jyM4CtW2Yn9xsPXt31cZnQuNe7sjU8

I just noticed that the Magenta _naive_rdft function doesn't quite match tf.signal.rfft's conventions, so if you're trying to use tf.signal.rfft in training but _naive_rdft in tf.lite, you'll need to make sure they match:

Here are the relevant code references in tf.signal that the above colab re-implements:

https://github.com/tensorflow/tensorflow/blob/3d8914bb88b666b813d2ce025ee10cc59fd39422/tensorflow/python/ops/signal/mfcc_ops.py#L97-L109
https://github.com/tensorflow/tensorflow/blob/3d8914bb88b666b813d2ce025ee10cc59fd39422/tensorflow/python/ops/signal/dct_ops.py#L120-L130
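A NumPy sketch of that replacement, assuming the conventions as read from the linked mfcc_ops.py and dct_ops.py: an unnormalized DCT-II (the scipy type-2 convention) computed from a zero-padded rfft, then scaled by 1/sqrt(2 * num_mel_bins). The matmul version precomputes the DCT basis as a constant so no FFT op is needed at inference time. `dct2_matrix`, `dct2_via_rfft` and `mfccs_via_matmul` are hypothetical names, and the check below compares the two paths in NumPy only, not in TFLite.

```python
import numpy as np


def dct2_matrix(n):
    """Unnormalized DCT-II basis as an [n, n] matrix (scipy type-2 convention)."""
    i = np.arange(n)[:, None]  # input sample index (rows)
    k = np.arange(n)[None, :]  # output coefficient index (columns)
    return 2.0 * np.cos(np.pi * k * (2 * i + 1) / (2.0 * n))


def dct2_via_rfft(x):
    """DCT-II from a 2n-point zero-padded rfft, the FFT-based route that needs RFFT."""
    n = x.shape[-1]
    scale = 2.0 * np.exp(-1j * np.pi * np.arange(n) / (2.0 * n))
    return np.real(np.fft.rfft(x, n=2 * n)[..., :n] * scale)


def mfccs_via_matmul(log_mel):
    """MFCCs as a single matmul: DCT-II scaled by 1/sqrt(2 * num_mel_bins)."""
    num_mel_bins = log_mel.shape[-1]
    basis = dct2_matrix(num_mel_bins) / np.sqrt(2.0 * num_mel_bins)
    return log_mel @ basis


rng = np.random.default_rng(0)
log_mel = rng.standard_normal((98, 64))  # [num_frames, num_mel_bins]
via_matmul = log_mel @ dct2_matrix(64)
via_fft = dct2_via_rfft(log_mel)
print(np.max(np.abs(via_matmul - via_fft)))

mfccs = mfccs_via_matmul(log_mel)[..., :13]  # keep the first 13 coefficients
```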

padoremu commented 4 years ago

@rryan Thank you, that's of great help!

Now, on my 2.0.0.dev20190731 installation, I get the following error when calling allocate_tensors after loading the tflite model: RuntimeError: tensorflow/lite/kernels/strided_slice.cc ellipsis_mask is not implemented yet.Node number 22 (STRIDED_SLICE) failed to prepare.

I will comment further once I find the time to dig into this, hopefully next week.

rryan commented 4 years ago

Ah, bummer. I vaguely remembered that ... in a strided slice was unsupported. If you know the rank you can simply hard-code the dimensions. Otherwise you can replace strided slice with a tf.slice (which is much more awkward). Sorry about all the trouble :(.

padoremu commented 4 years ago

Thank you again for the great and fast support! I will go with the hard coded dimensions for now. Thanks again.

padoremu commented 4 years ago

As expected, replacing [..., :axis_dim] by [:, :, :axis_dim] and [..., :dct_coefficient_count] by [:, :, :dct_coefficient_count] worked in my case. Thanks! :)
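A tiny NumPy illustration of the same rewrite (NumPy and TF share these slicing semantics). `slice_last_axis` is a hypothetical helper that spells out one `slice(None)` per leading axis instead of using `...`; this assumes the rank is known when the graph is built (in TF it would come from the static shape), and whether the converter then emits a strided slice without `ellipsis_mask` is worth confirming on your TF version.

```python
import numpy as np


def slice_last_axis(x, stop):
    """Equivalent of x[..., :stop] written without an Ellipsis (rank must be known)."""
    leading = (slice(None),) * (x.ndim - 1)
    return x[leading + (slice(0, stop),)]


x = np.arange(2 * 3 * 8).reshape(2, 3, 8)
axis_dim = 5

with_ellipsis = x[..., :axis_dim]  # original form, needs ellipsis_mask
hard_coded = x[:, :, :axis_dim]    # rank-3 rewrite from the comment above
generic = slice_last_axis(x, axis_dim)
```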

dkashkin commented 4 years ago

@padoremu can you please share the full code snippet that works for you? This workaround sounds promising!

padoremu commented 4 years ago

@dkashkin My python file has 350 lines of code, so I wouldn't want to paste it. Is it allowed to attach a python file as a TXT file?

dkashkin commented 4 years ago

@padoremu I'd hugely appreciate if you can create a public gist on Github and post a link to it. Thank you!

padoremu commented 4 years ago

@dkashkin Thanks for the hint. There you go: public gist

I haven't compared the output to that of the TF functions yet though. The file also contains other test functions for problems with TF / TFLite in the context of audio processing, as described in the file header.

adelcast commented 4 years ago

Really helpful discussion; the Magenta implementation worked for me as a replacement for tf.signal.stft. In the model I am using, I need to convert back to the time domain via tf.signal.inverse_stft, which is currently failing. Does anyone know of a TFLite-friendly implementation of inverse_stft? I can code it in Kotlin, but it would be awesome if I could just do everything in the model.
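One possible starting point, offered as a NumPy sketch that has not been tried in TFLite: the per-frame inverse of rfft can also be written as matmuls, mirroring the forward RDFT trick suggested earlier in this thread. It assumes an even fft_length (so the last bin is the Nyquist bin) and takes the real and imaginary parts as separate real arrays; `irdft_matrices` and `irdft` are hypothetical names. A full inverse_stft also needs the synthesis window and overlap-add of the frames (tf.signal.inverse_stft does this internally, and tf.signal.overlap_and_add exists as a separate op), which this sketch leaves out.

```python
import numpy as np


def irdft_matrices(fft_length):
    """Matrices that map rfft bins (real, imag) back to fft_length real samples."""
    if fft_length % 2:
        raise ValueError("this sketch assumes an even fft_length")
    num_bins = fft_length // 2 + 1
    k = np.arange(num_bins)[:, None]    # frequency bin (rows)
    n = np.arange(fft_length)[None, :]  # time index (columns)
    angle = 2.0 * np.pi * k * n / fft_length
    # DC and Nyquist appear once in the full spectrum; every other bin also has
    # a conjugate mirror at fft_length - k, so it is counted twice.
    weight = np.full((num_bins, 1), 2.0)
    weight[0] = 1.0
    weight[-1] = 1.0
    cos_m = weight * np.cos(angle) / fft_length
    sin_m = -weight * np.sin(angle) / fft_length
    return cos_m, sin_m


def irdft(real, imag, fft_length):
    """Inverse rfft for a batch of frames using only real matmuls and an add."""
    cos_m, sin_m = irdft_matrices(fft_length)
    return real @ cos_m + imag @ sin_m


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 512))
spec = np.fft.rfft(x)
frames_back = irdft(spec.real, spec.imag, 512)
```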

antonyharfield commented 4 years ago

Thanks @rryan @padoremu for the very helpful discussion. I was able to save Yamnet as tflite using your suggestions. 😅

sailorbj commented 4 years ago

Hi @antonyharfield , to save Yamnet as tflite, did you mean to use rdft to replace stft and abs? Can you share that part of code?

Thanks, I found the answer by reading rryan's reply. By the way, is the TFLite model for YAMNet about 16 MB? I also tried a quantized TFLite model, and its predictions are very different from those of the non-quantized TFLite model. How about yours?

rybakov commented 4 years ago

Just in case you are interested in training a sound detection model on Audioset and running it with TFLite, please feel free to try kws_streaming. It already has dozens of different neural network topologies (including attention-based ones) that outperform standard convolutional neural nets. All models are compatible with TFLite and designed to run on mobile phones; you can also quantize your models and run them in streaming mode if needed.

navid-a commented 4 years ago

Has anyone verified whether one of the alternative TFLite-compatible approaches produces the same results as the tf.signal.stft approach? I tried the code by @padoremu (which seems to be based on Magenta), and while that code works and I get reasonable log_mel_spectrograms, they are not exactly the same as the ones from tf.signal.

mylyu commented 4 years ago

I got nearly identical results. You may want to check some parameters inside the code, making sure they are consistent with yours.
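As an example of the kind of parameter mismatch to look for: tf.signal.stft uses tf.signal.hann_window by default, which is a periodic Hann window, while NumPy's np.hanning is symmetric. The NumPy sketch below (`periodic_hann` and `stft_magnitude` are hypothetical helper names) shows that the two windows give measurably different magnitudes on the same frames; fft_length, frame_step and the log offset are other usual suspects.

```python
import numpy as np


def periodic_hann(window_length):
    """Periodic Hann window: 0.5 - 0.5 * cos(2 * pi * n / window_length)."""
    n = np.arange(window_length)
    return 0.5 - 0.5 * np.cos(2.0 * np.pi * n / window_length)


def stft_magnitude(frames, window, fft_length):
    """|rfft| of windowed frames; frames has shape [num_frames, frame_length]."""
    return np.abs(np.fft.rfft(frames * window, n=fft_length))


rng = np.random.default_rng(0)
frames = rng.standard_normal((98, 400))

periodic = periodic_hann(400)
symmetric = np.hanning(400)  # symmetric Hann: a common source of small mismatches

mag_periodic = stft_magnitude(frames, periodic, 512)
mag_symmetric = stft_magnitude(frames, symmetric, 512)
print(np.max(np.abs(mag_periodic - mag_symmetric)))
```

If a NumPy reference is needed, np.hanning(400 + 1)[:-1] reproduces the periodic window.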

antonyharfield commented 4 years ago

@sailorbj I have put my YAMNet code in a repo and also written up the steps if you are still interested: Converting YAMNet audio detection model for TFLite

JanSob commented 4 years ago

Is there any ETA on the arrival of the rfft op in TFLite? It would make the whole audio-detection pipeline on Android and iOS much, much easier :/

andreselizondo-adestech commented 4 years ago

Working based on @padoremu 's gist, I replaced _fixed_frame() with tf.signal.frame. This allows for variable sized inputs and maintains compatibility with TFLite.

Edit is as simple as: [image]
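The edit itself was posted as a screenshot. As a rough guide to what tf.signal.frame computes with its default pad_end=False, here is a NumPy analogue for 1-D signals (`frame` is a hypothetical name, and the body of _fixed_frame() from the gist is not reproduced here). Nothing in it hard-codes the input length, which is what makes variable-sized inputs possible.

```python
import numpy as np


def frame(signal, frame_length, frame_step):
    """NumPy analogue of tf.signal.frame(signal, frame_length, frame_step) for 1-D input."""
    # Number of full frames that fit; partial trailing frames are dropped.
    num_frames = max(0, 1 + (len(signal) - frame_length) // frame_step)
    # Row i holds samples [i * frame_step, i * frame_step + frame_length).
    starts = np.arange(num_frames)[:, None] * frame_step
    offsets = np.arange(frame_length)[None, :]
    return signal[starts + offsets]


one_second = np.arange(16000, dtype=np.float32)
frames = frame(one_second, frame_length=400, frame_step=160)
half_second = frame(np.arange(8000, dtype=np.float32), 400, 160)
```

For the 16000-sample, 400/160 setup used earlier in this thread this gives 98 frames of 400 samples; an 8000-sample clip gives 48.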

Arvindia commented 4 years ago

@JanSo19 Seems like it's in their roadmap. https://groups.google.com/a/tensorflow.org/g/tflite/c/y4gAAavb4zs https://www.tensorflow.org/lite/guide/roadmap

Path-A commented 4 years ago

Has anyone had success with audio_microfrontend_op from tensorflow.lite.experimental.microfrontend.python.ops?

Zepyhrus commented 3 years ago

@andreselizondo-adestech Sigh, replacing _fixed_frame() with tf.signal.frame is not working for me, due to the BatchMul incompatibility with TFLite.

Zepyhrus commented 3 years ago

@padoremu Great work!! This is a lifesaver. How do you run inference in TFLite if the input shape is fixed? (E.g., I'm training with a batch size of 256, but I obviously can't run inference with a fixed batch of 256. Would padding with empty samples help?)

padoremu commented 3 years ago

@Zepyhrus Thanks, glad to hear that it helped! Unfortunately I can't help with further TF Lite related questions, as it pretty soon turned out that TF Lite was too heavyweight for our needs and TF Lite Micro neither supported STFT nor RNNs, and it didn't seem like that would happen soon. So we went a completely different way, which we are very happy with (but using TF for training).

Zepyhrus commented 3 years ago

@padoremu Big fan of your code; your playground shows great expertise in both TF and audio processing. Would you mind sharing a little about your solution? We will stick to TFLite, though.