Status: open. sigpro opened this issue 5 years ago.
I had a similar problem with the missing output shape on the latest TensorFlow. The following modification to tensorflow_binding/src/warpctc_op.cc did the trick for me: add the #include, add the .SetShapeFn(...) call to the op registration, and recompile. The code is copied and adapted from here: https://www.tensorflow.org/guide/extend/op
#include "tensorflow/core/framework/shape_inference.h"

REGISTER_OP("WarpCTC")
    .Input("activations: float32")
    .Input("flat_labels: int32")
    .Input("label_lengths: int32")
    .Input("input_lengths: int32")
    .Attr("blank_label: int = 0")
    .Output("costs: float32")
    .Output("gradients: float32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      // costs (output 0) gets the shape of input_lengths (input 3).
      c->set_output(0, c->input(3));
      // gradients (output 1) gets the shape of activations (input 0).
      c->set_output(1, c->input(0));
      return ::tensorflow::Status::OK();
    });
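To make it concrete what this shape function does, here is a minimal plain-C++ sketch with no TensorFlow dependency. `Shape`, `kUnknown`, `WarpCtcShapes` and `SimpleWarpCtcShapeFn` are hypothetical names for illustration, not TensorFlow APIs; a static shape is modelled as a vector of dimensions, with -1 standing for "unknown at graph-build time". Like the two set_output calls above, it only copies input shapes to outputs and validates nothing.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical toy model, not TensorFlow code: a static shape is a list of
// dimensions, and kUnknown marks a dimension not known when the graph is built.
constexpr std::int64_t kUnknown = -1;
using Shape = std::vector<std::int64_t>;

struct WarpCtcShapes {
  Shape costs;      // output 0
  Shape gradients;  // output 1
};

// Mirrors the simple shape function: costs takes the shape of input 3
// (input_lengths) and gradients the shape of input 0 (activations).
// Nothing is validated, so malformed input shapes pass through silently.
WarpCtcShapes SimpleWarpCtcShapeFn(const Shape& activations,
                                   const Shape& input_lengths) {
  return {input_lengths, activations};
}
```

For example, with activations of shape {50, kUnknown, 29} and input_lengths of shape {kUnknown}, costs comes out as {kUnknown} and gradients as {50, kUnknown, 29}. The output rank is at least known, so get_shape().as_list() no longer raises, even when individual dimensions are still unknown.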
For stricter shape checking (rank validation and a consistent batch size across inputs), I modified the answer by @fginter based on the original TensorFlow .SetShapeFn implementation of the CTCLoss op.
#include "tensorflow/core/framework/shape_inference.h"

using ::tensorflow::shape_inference::DimensionHandle;
using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeHandle;
using ::tensorflow::Status;

REGISTER_OP("WarpCTC")
    .Input("activations: float32")
    .Input("flat_labels: int32")
    .Input("label_lengths: int32")
    .Input("input_lengths: int32")
    .Attr("blank_label: int = 0")
    .Output("costs: float32")
    .Output("gradients: float32")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle activations;
      ShapeHandle flat_labels;
      ShapeHandle label_lengths;
      ShapeHandle input_lengths;
      // Validate ranks: activations is 3-D, the other inputs are vectors.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &activations));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &flat_labels));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &label_lengths));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &input_lengths));
      // Get the batch size from activations (dim 1) and input_lengths
      // (dim 0), and update activations with the merged batch_size since
      // it is returned as the gradients shape.
      DimensionHandle batch_size;
      TF_RETURN_IF_ERROR(
          c->Merge(c->Dim(activations, 1), c->Dim(input_lengths, 0), &batch_size));
      TF_RETURN_IF_ERROR(c->ReplaceDim(activations, 1, batch_size, &activations));
      // costs: one value per batch entry; gradients: same shape as activations.
      c->set_output(0, c->Vector(batch_size));
      c->set_output(1, activations);
      return Status::OK();
    });
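The extra checks in this version can be illustrated with a plain-C++ sketch that has no TensorFlow dependency. It is a hypothetical toy: `CheckRank`, `MergeDim` and `StrictWarpCtcShapeFn` are made-up names that only imitate WithRank, Merge and ReplaceDim, and unlike real TensorFlow the model assumes every rank is known. It rejects inputs of the wrong rank, lets an unknown batch dimension take its value from the other input, and throws if two known batch sizes disagree.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical toy model, not TensorFlow code. A static shape is a list of
// dimensions; kUnknown marks a dimension not known at graph-build time.
// Unlike TensorFlow, this model assumes the rank itself is always known.
constexpr std::int64_t kUnknown = -1;
using Shape = std::vector<std::int64_t>;

struct WarpCtcShapes {
  Shape costs;      // output 0
  Shape gradients;  // output 1
};

// Analogue of InferenceContext::WithRank: reject a shape of the wrong rank.
void CheckRank(const Shape& s, std::size_t rank) {
  if (s.size() != rank) throw std::invalid_argument("unexpected rank");
}

// Analogue of InferenceContext::Merge for one dimension: an unknown
// dimension takes the other's value; two known dimensions must agree.
std::int64_t MergeDim(std::int64_t a, std::int64_t b) {
  if (a == kUnknown) return b;
  if (b == kUnknown) return a;
  if (a != b) throw std::invalid_argument("incompatible batch dimensions");
  return a;
}

// Mirrors the stricter shape function: check all ranks, merge dim 1 of
// activations with dim 0 of input_lengths, write the merged batch size back
// into the activations shape, and use it for both outputs.
WarpCtcShapes StrictWarpCtcShapeFn(Shape activations, const Shape& flat_labels,
                                   const Shape& label_lengths,
                                   const Shape& input_lengths) {
  CheckRank(activations, 3);
  CheckRank(flat_labels, 1);
  CheckRank(label_lengths, 1);
  CheckRank(input_lengths, 1);
  const std::int64_t batch_size = MergeDim(activations[1], input_lengths[0]);
  activations[1] = batch_size;  // like c->ReplaceDim(activations, 1, ...)
  return {Shape{batch_size}, activations};
}
```

With activations {50, kUnknown, 29} and input_lengths {16}, the merged batch size 16 flows into both outputs (costs {16}, gradients {50, 16, 29}). With activations {50, 8, 29} and the same input_lengths, the merge fails, roughly the way graph construction would stop with a shape error in TensorFlow.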
Traceback (most recent call last):
  File "/home/work/speech/libs/warp-ctc/tensorflow_binding/tests/test_warpctc_op.py", line 124, in test_multiple_batches_cpu
    self._test_multiple_batches(use_gpu=False)
  File "/home/work/speech/libs/warp-ctc/tensorflow_binding/tests/test_warpctc_op.py", line 121, in _test_multiple_batches
    use_gpu=use_gpu)
  File "/home/work/speech/libs/warp-ctc/tensorflow_binding/tests/test_warpctc_op.py", line 28, in _run_ctc
    self.assertShapeEqual(expected_costs, costs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/test_util.py", line 1639, in assertShapeEqual
    np_array.shape, tf_tensor.get_shape().as_list(), msg=msg)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/tensor_shape.py", line 903, in as_list
    raise ValueError("as_list() is not defined on an unknown TensorShape.")
ValueError: as_list() is not defined on an unknown TensorShape.
When I print the shape of costs, it is unknown. If I evaluate costs, however, both the shape and the values are correct. So what does the ctc op actually return? My environment is TensorFlow 1.10.1 on Ubuntu 16.04.