githubharald / CTCWordBeamSearch

Connectionist Temporal Classification (CTC) decoder with dictionary and language model.
https://towardsdatascience.com/b051d28f3d2e
MIT License

ValueError: the number of characters (chars) plus 1 must equal dimension 2 of the input tensor (mat) #67

Closed UniDuEChristianGold closed 1 year ago

UniDuEChristianGold commented 1 year ago

Hello and thank you for your SimpleHTR and CTCWordBeamSearch Repos.

I am trying to use the latter in my code, but I am stuck on the ValueError mentioned in the title: ValueError: the number of characters (chars) plus 1 must equal dimension 2 of the input tensor (mat)

In my code, I initialize the decoder with self.wbs_decoder = WordBeamSearch(50, 'Words', 0.0, corpus.encode('utf8'), wbs_chars.encode('utf8'), word_chars.encode('utf8')), where wbs_chars is a string with 91 characters.

After predicting 5 text lines with 128 time steps, the result is stored in out via out = self.model_pred.predict_on_batch(x), and print(out.shape) gives (128, 5, 92).

With this, the input should match the required pattern TxBx(C+1), and thus I do not understand the ValueError that is raised when calling label_str = self.wbs_decoder.compute(out)
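For reference, the shape constraint described above can be sketched as a small sanity check before calling the decoder (the function name here is illustrative, not part of the library):

```python
import numpy as np

def check_decoder_input(mat, chars):
    """Hypothetical sanity check: mat must be (T, B, C+1) with C = len(chars)."""
    T, B, depth = mat.shape
    if depth != len(chars) + 1:
        raise ValueError(
            f"dimension 2 of mat is {depth}, but len(chars) + 1 is {len(chars) + 1}"
        )
    return T, B

# T=128 time steps, B=5 text lines, C+1=92 output entries
mat = np.zeros((128, 5, 92))
check_decoder_input(mat, "x" * 91)  # passes, since 91 + 1 == 92
```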

Is there anything obvious I am missing?

What I have tried so far:

  1. I increased wbs_chars by one 'space' character, although your test files/examples pass a string of length C, not C+1.
  2. I was unsure about "dimension 2" and swapped B with C+1, ending in out.shape -> (128, 92, 5).

With both, I ended up with the above ValueError.

In advance I thank you for your help.

githubharald commented 1 year ago

Hi,

can you please dump the content of wbs_chars.encode('utf8') and word_chars.encode('utf8') by doing

print(wbs_chars.encode('utf8'))
print(word_chars.encode('utf8'))

The output of the RNN must have C+1 entries, as it includes the special "CTC blank" character. There are C characters to be recognized, and the number of word characters should be less than C (e.g. C-1), as they do not include word-separation characters like the whitespace.

To give an example: the RNN outputs the characters " AB~", where "~" denotes the special blank character. The characters such a model can recognize are " AB", and the word characters are "AB", as the whitespace " " serves as a word-separation character (as in most languages).
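The relationship between the two character lists in this example can be written down as a quick check (the variable names are illustrative):

```python
chars = " AB"        # characters the model can recognize (C = 3)
word_chars = "AB"    # characters that form words; no separators included

# word_chars must be a subset of chars ...
assert set(word_chars) <= set(chars)

# ... and the RNN output has C + 1 entries, the extra one being the CTC blank
rnn_output_depth = len(chars) + 1  # -> 4
```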

Here is an example of how to use it: https://github.com/githubharald/SimpleHTR/blob/master/src/model.py#L142 And this is where the error comes from; you can see the condition for this error check here: https://github.com/githubharald/CTCWordBeamSearch/blob/82824268694a541608a38940125ab14fa3993613/cpp/TFWordBeamSearch.cpp#L195

UniDuEChristianGold commented 1 year ago

Thank you for your answer. Yes, I understand that the word_chars is a smaller subset of chars.

Here is the printout: print(wbs_chars.encode('utf8')) b' !"#&\'()*+,-./|\0123456789:;?ABCDEFGHIJKLMNOPRQSTUVWXYZabcdefghijklmnopqrstuvwxyz|}\xc3\x84\xc3\x9c\xc3\x9f\xc3\xa4\xc3\xb6\xc3\xbc\xe2\x80\x9c\xe2\x80\x9e'

print(word_chars.encode('utf8')) b"'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\xc3\xa4\xc3\xb6\xc3\xbc\xc3\x84\xc3\x96\xc3\x9c\xc3\x9f"

I added another printout to NPWordBeamSearch.cpp (not sure if TF vs. NP makes a difference here, but I doubt it, as the chars shouldn't be influenced by this): std::cout << "maxC " << maxC << " m_numChars " << m_numChars << '\n'; -> maxC 92 m_numChars 90. So one character is missing from m_numChars/wbs_chars, as it should be 91.

print(len(wbs_chars)) -> 91 (!)

print(wbs_chars.encode('utf8'))
print(word_chars.encode('utf8'))
self.wbs_decoder = WordBeamSearch(50, 'Words', 0.0, corpus.encode('utf8'), wbs_chars.encode('utf8'), word_chars.encode('utf8'))

So it seems that one character is lost in: m_numChars = m_lm->getAllChars().size();

I was able to track down the issue: I had added the character | twice to my list, and getAllChars removes duplicate characters. Thank you so much for your help!
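Duplicates like the one found above can be spotted up front with a small helper (a sketch, not part of the library):

```python
from collections import Counter

def find_duplicates(chars):
    """Return every character that appears more than once, in first-seen order.

    A duplicate in the character list silently shrinks the deduplicated
    character count, causing the C+1 check in the decoder to fail.
    """
    return [c for c, n in Counter(chars).items() if n > 1]

print(find_duplicates('ab|cd|'))  # -> ['|']
```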

githubharald commented 1 year ago

Good that you found the issue 👍. Just as a side note: since you removed one character from your list, make sure that the order in which the characters now occur in the list is the same as the order in which they occur in the RNN output.
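One way to follow this advice is to drop duplicates while keeping the first occurrence of each character, so the remaining characters stay in their original order (a sketch with an illustrative name):

```python
def dedupe_keep_order(chars):
    """Remove duplicate characters, keeping the first occurrence of each,
    so the surviving order still matches the RNN output columns."""
    seen = set()
    return ''.join(c for c in chars if not (c in seen or seen.add(c)))

print(dedupe_keep_order('ab|cd|e'))  # -> 'ab|cde'
```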