milesial opened this issue 1 year ago
The shape inference logic in the ONNX STFT op definition looks a bit strange to me:
```cpp
bool is_onesided = static_cast<bool>(getAttribute(ctx, "onesided", 0));
if (is_onesided) {
  dft_size = is_onesided ? ((dft_size >> 1) + 1) : dft_size;
}
auto n_dfts = static_cast<int64_t>((signal_size - dft_size) / static_cast<float>(frame_step_value)) + 1;
// The output has the following shape: [batch_size][frames][dft_unique_bins][2]
ONNX_NAMESPACE::TensorShapeProto result_shape_proto;
result_shape_proto.add_dim()->set_dim_value(input_shape.dim(0).dim_value()); // batch size
result_shape_proto.add_dim()->set_dim_value(n_dfts);
result_shape_proto.add_dim()->set_dim_value(dft_size);
result_shape_proto.add_dim()->set_dim_value(2);
```
It looks like the author intended `dft_size` to represent the window/frame length, but then repurposed it to represent the `dft_unique_bins` mentioned in the description. That makes the math for computing `n_dfts` wrong: when `onesided` is set, the number of frames is computed from the number of unique bins instead of from the frame length. I would instead expect something like this:
```diff
 bool is_onesided = static_cast<bool>(getAttribute(ctx, "onesided", 0));
-if (is_onesided) {
-  dft_size = is_onesided ? ((dft_size >> 1) + 1) : dft_size;
-}
+int64_t dft_unique_bins = is_onesided ? ((dft_size >> 1) + 1) : dft_size;
 auto n_dfts = static_cast<int64_t>((signal_size - dft_size) / static_cast<float>(frame_step_value)) + 1;
 // The output has the following shape: [batch_size][frames][dft_unique_bins][2]
 ONNX_NAMESPACE::TensorShapeProto result_shape_proto;
 result_shape_proto.add_dim()->set_dim_value(input_shape.dim(0).dim_value()); // batch size
 result_shape_proto.add_dim()->set_dim_value(n_dfts);
-result_shape_proto.add_dim()->set_dim_value(dft_size);
+result_shape_proto.add_dim()->set_dim_value(dft_unique_bins);
 result_shape_proto.add_dim()->set_dim_value(2);
```
Describe the issue
The STFT op's inferred ("expected") output shape does not match the shape of its actual (correct) output because of an off-by-one error.
Originally from https://github.com/pytorch/pytorch/pull/92087#issuecomment-1383235742
This prevents adding new nodes after the STFT node, because they fail with a mismatch against the expected shape.
To reproduce
Here is a standalone repro that creates a graph containing only an STFT op; the runtime throws the error. It also happens with graph optimizations enabled.
Urgency
No response
Platform
Linux
OS Version
Ubuntu 22.10
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.13.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU
Execution Provider Library Version
No response