chenmoneygithub opened 2 years ago
I think this one is actually not about `BertPreprocessor` so much as the Bert functional model. Basically, you would see this issue if you did:
```python
import tensorflow as tf
import keras_nlp

# A dict of rank-1 inputs: sequence dimension only, no batch dimension.
unbatched_input = {
    "token_ids": tf.ones((128,), dtype="int32"),
    "segment_ids": tf.ones((128,), dtype="int32"),
    "padding_mask": tf.ones((128,), dtype="int32"),
}
model = keras_nlp.models.Bert.from_preset("bert_base_uncased_en")
model(unbatched_input)
```
We want to preserve the ability to run preprocessing before batching, so it would be incorrect to "uprank" a batch dim at preprocessing time. So the question is how we want our model to handle unbatched inputs.
I agree the current error is really bad. I'm not sure whether we should automatically uprank, or just error if unbatched inputs are passed in. Both seem better than what we currently have.
Our options seem like:

```python
# Option 1.
model(unbatched_input)  # -> OK (with an expand_dims on the first axis)

# Option 2.
model(unbatched_input)  # -> Error("Inputs must be batched! Try ds.batch() or adding a rank to raw inputs.")
```
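For Option 1, a minimal sketch of what automatic upranking could look like, assuming we wrap the call rather than change the functional model itself (the `call_with_uprank` helper name is hypothetical, not existing API):

```python
import tensorflow as tf

# Hypothetical sketch of Option 1: uprank rank-1 inputs to rank-2 before
# calling the model, then squeeze the batch dimension back out.
def call_with_uprank(model, inputs):
    unbatched = all(len(v.shape) == 1 for v in inputs.values())
    if unbatched:
        inputs = {k: tf.expand_dims(v, 0) for k, v in inputs.items()}
    outputs = model(inputs)
    if unbatched:
        outputs = tf.nest.map_structure(lambda t: tf.squeeze(t, 0), outputs)
    return outputs
```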
The interesting thing is that I think core Keras functional models already do some automatic upranking, and it is probably doing something wrong here (e.g. upranking on the last dim instead of the first dim). The first thing we should do is to dive into the functional model input code and figure out what is going wrong. Then let's use that to inform our decision on what our behavior is here.
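As a starting point for that digging, here is a minimal probe, independent of KerasNLP, that exercises the core functional model path with a rank-1 input against a declared rank-2 input spec:

```python
import tensorflow as tf
from tensorflow import keras

# Minimal probe: a functional model declared with (batch, 128) inputs.
inputs = keras.Input(shape=(128,), dtype="int32")
outputs = tf.reduce_sum(inputs, axis=-1)
model = keras.Model(inputs, outputs)

# Feed a rank-1 tensor and observe what the input-conforming code does:
# a silent uprank (possibly on the wrong axis) or an opaque shape error.
try:
    print(model(tf.ones((128,), dtype="int32")))
except Exception as e:
    print(f"{type(e).__name__}: {e}")
```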
I assume this means that the docstring example is wrong as well? Wish we had a way to run these automatically!
I do see the string passed in as an iterable in all our unit tests.
It looks like the docstring example will not run because our vocabulary is missing special tokens, but that is unrelated.
Calling our preprocessing on an unbatched input is totally correct to do, and it will produce an unbatched output. This is generally true of our preprocessing layers, which support batching first or batching later. This is consistent with Keras preprocessing offerings generally:
```python
ds = ds.map(preprocessor).batch(8)
ds = ds.batch(8).map(preprocessor)
```
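As a fuller sketch of those two equivalent orderings, assuming a `tf.data.Dataset` of raw strings and the preset name used earlier in this thread:

```python
import tensorflow as tf
import keras_nlp

# Assumed setup: a preprocessor preset and a small dataset of raw strings.
preprocessor = keras_nlp.models.BertPreprocessor.from_preset("bert_base_uncased_en")
ds = tf.data.Dataset.from_tensor_slices(["the quick brown fox", "hello world"])

# Preprocess unbatched, then batch: the preprocessor sees scalar strings.
batch_later = ds.map(preprocessor).batch(2)

# Batch first, then preprocess: the preprocessor sees rank-1 string batches.
batch_first = ds.batch(2).map(preprocessor)
```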
I think it is helpful to push aside the preprocessing part of the question, to narrow in on the core question of this issue. If a user calls our model with a set of inputs with only a sequence dimension and no batch dimension (e.g. rank 1 inputs), how do we want our model to respond?
Right now we give the very inscrutable error message Chen posted above, and I'm fairly sure it has to do with upstream Keras functional model code doing an incorrect uprank on the wrong axis. Fixing this will require some spelunking into functional model code and thinking about the correct behavior there, but we can do this independent of the preprocessing question.
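If we end up preferring Option 2, a hedged sketch of what an explicit guard could look like (the `check_batched` helper is hypothetical, not existing API):

```python
import tensorflow as tf

# Hypothetical guard for Option 2: fail fast with a readable message
# instead of letting the functional model produce an opaque shape error.
def check_batched(inputs):
    for name, tensor in inputs.items():
        if len(tensor.shape) < 2:
            raise ValueError(
                f"Input `{name}` must be batched (rank >= 2), but got shape "
                f"{tensor.shape}. Try `ds.batch()` or adding a batch dim to "
                "raw inputs."
            )
```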
Also, re testing fenced code blocks, we absolutely could do this if we wanted to sink a little effort in. I can open up an issue.
This is fixed for the task level objects in https://github.com/keras-team/keras-nlp/pull/545.
But it is still the case that if you pass an unbatched input to the backbone directly (e.g. you create a dict of inputs with shape (128,) and no batch dimension), you get the confusing error in this issue description. I think this may actually be an issue with functional models in core Keras. I will leave this open to track, but this should now be non-blocking for the 0.4 release.
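In the meantime, a sketch of the obvious workaround when calling the backbone directly, reusing the `unbatched_input` dict and `model` from the snippet at the top of this thread:

```python
import tensorflow as tf

# Workaround sketch: expand each input to shape (1, 128) before calling
# the backbone directly, since the backbone expects batched inputs.
batched_input = {k: tf.expand_dims(v, 0) for k, v in unbatched_input.items()}
outputs = model(batched_input)
```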
See the code snippet below:
The code throws an error:
For a smooth UX, the code snippet is expected to work, as it demonstrates how a single string flows through the Bert model.
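The original snippet was not preserved in this thread; as a hedged reconstruction of the kind of single-string flow the issue describes (class and preset names follow the ones used above, and are assumptions):

```python
import keras_nlp

# Hedged reconstruction, not the original snippet: run a single raw string
# through preprocessing and then the backbone, with no batch dimension.
preprocessor = keras_nlp.models.BertPreprocessor.from_preset("bert_base_uncased_en")
model = keras_nlp.models.Bert.from_preset("bert_base_uncased_en")

features = preprocessor("the quick brown fox")  # unbatched: shapes (sequence_length,)
outputs = model(features)  # raises the confusing error described above
```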