Calamari-OCR / calamari

Line based ATR Engine based on OCRopy
Apache License 2.0

output_dir parameter does not work as described in doc string #309

Status: Open. andbue opened this issue 2 years ago.

andbue commented 2 years ago
bertsky commented 1 year ago
  • for PAGE: at the moment, output_dir only defines the dir for extended predictions

Indeed, this is quite surprising, and annoying: to keep multiple runs from overwriting each other's results, you have to give each run a different extension.

IMO all we need to do is add an output_dir attribute to the PAGE reader, set per page during store_text_prediction() and applied later during store().

What do you think?
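The proposal above (remember an output directory per page while predicting, use it when writing) can be sketched outside calamari as follows. This is a minimal illustration under stated assumptions, not calamari's API: the class name OutputDirTracker, the simplified method signatures, the `contents` argument (standing in for the serialized PAGE XML) and the creation of missing directories are all hypothetical.

```python
import os
import tempfile


class OutputDirTracker:
    """Hypothetical stand-in for the per-page bookkeeping proposed for the
    PAGE reader; it only mimics the prepare/predict/store lifecycle."""

    def prepare_store(self):
        # Reset the page_id -> output directory mapping before a run.
        self._output_dir = {}

    def store_text_prediction(self, page_id, output_dir=None):
        # Without an explicit output_dir, fall back to the page's own
        # directory, i.e. the current behaviour of writing next to the input.
        self._output_dir[page_id] = output_dir or os.path.dirname(page_id)

    def store(self, extension, contents):
        # Write every recorded page into the directory remembered for it.
        # `contents` maps page_id -> serialized XML (hypothetical input).
        written = []
        for page_id, out_dir in self._output_dir.items():
            if out_dir:
                # Assumption: create missing directories (the patch does not).
                os.makedirs(out_dir, exist_ok=True)
            path = os.path.join(out_dir, os.path.basename(page_id) + extension)
            with open(path, "w", encoding="utf-8") as f:
                f.write(contents[page_id])
            written.append(path)
        return written


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        tracker = OutputDirTracker()
        tracker.prepare_store()
        page_a = os.path.join(tmp, "in", "page_001")
        page_b = os.path.join(tmp, "in", "page_002")
        tracker.store_text_prediction(page_a, os.path.join(tmp, "out"))
        tracker.store_text_prediction(page_b)  # no output_dir given
        for path in tracker.store(".pred.xml", {page_a: "<a/>", page_b: "<b/>"}):
            print(os.path.relpath(path, tmp))
```

Keying the mapping by page rather than storing a single directory keeps the decision local to each page, so pages that received an explicit output_dir and pages that did not can be mixed in one run.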

bertsky commented 1 year ago

In full:

diff --git a/calamari_ocr/ocr/dataset/datareader/pagexml/reader.py b/calamari_ocr/ocr/dataset/datareader/pagexml/reader.py
index 71a4da5..d9084d0 100644
--- a/calamari_ocr/ocr/dataset/datareader/pagexml/reader.py
+++ b/calamari_ocr/ocr/dataset/datareader/pagexml/reader.py
@@ -390,10 +390,13 @@ class PageXMLReader(CalamariDataGenerator[PageXML]):
     def prepare_store(self):
         self._last_page_id = None
         self._next_word_id = 0
+        self._output_dir = dict()

     def store_text_prediction(self, prediction, sample_id, output_dir):
         sentence = prediction.sentence
         sample = self.sample_by_id(sample_id)
+        output_dir = output_dir or os.path.dirname(sample['page_id'])
+        self._output_dir[sample['page_id']] = output_dir
         ns = sample["ns"]
         line = sample["xml_element"]
         textequivxml = line.find('./ns:TextEquiv[@index="{}"]'.format(self.params.text_index), namespaces=ns)
@@ -440,8 +443,11 @@ class PageXMLReader(CalamariDataGenerator[PageXML]):
                 desc="Writing PageXML files",
                 total=len(self.params.xmlfiles),
             ):
-                page = self.pages(split_all_ext(xml)[0])
-                with open(split_all_ext(xml)[0] + extension, "w", encoding="utf-8") as f:
+                page_id = split_all_ext(xml)[0]
+                page = self.pages[page_id]
+                path = os.path.join(self._output_dir.get(page_id, os.path.dirname(page_id)),
+                                    filename(xml) + extension)
+                with open(path, "w", encoding="utf-8") as f:
                     f.write(etree.tounicode(page.getroottree(), pretty_print=True))

     @staticmethod
@@ -619,7 +625,9 @@ class PageXMLReader(CalamariDataGenerator[PageXML]):

     def _store_page(self, extension, page_id):
         page = self.pages[page_id]
-        with open(split_all_ext(page_id)[0] + extension, "w", encoding="utf-8") as f:
+        path = os.path.join(self._output_dir[page_id],
+                            filename(page_id) + extension)
+        with open(path, "w", encoding="utf-8") as f:
             f.write(etree.tounicode(page.getroottree(), pretty_print=True))

     def _sample_iterator(self):
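The effect of the patch on the target file name can be checked in isolation. The sketch below compares the old target (`split_all_ext(page_id)[0] + extension`) with the patched one (`os.path.join(<output dir>, filename(page_id) + extension)`). `split_all_ext` and `filename` here are local stand-ins, written on the assumption that they behave like the calamari helpers the diff calls (strip every extension; `filename` additionally drops the directory). `old_target` and `new_target` are hypothetical names, and the sketch falls back to the page's own directory when no output directory was recorded for a page.

```python
import os


def split_all_ext(path):
    """Stand-in: split 'dir/name.a.b' into ('dir/name', '.a.b')."""
    directory, base = os.path.split(path)
    stem, dot, rest = base.partition(".")
    return os.path.join(directory, stem), dot + rest


def filename(path):
    """Stand-in: base name of `path` with every extension removed."""
    return split_all_ext(os.path.basename(path))[0]


def old_target(page_id, extension):
    # Before the patch: the result always lands next to the input file.
    return split_all_ext(page_id)[0] + extension


def new_target(page_id, extension, output_dirs):
    # After the patch: the recorded output directory, else the input's own.
    out_dir = output_dirs.get(page_id, os.path.dirname(page_id))
    return os.path.join(out_dir, filename(page_id) + extension)


if __name__ == "__main__":
    page_id = os.path.join("data", "page_001")
    print(old_target(page_id, ".pred.xml"))
    print(new_target(page_id, ".pred.xml", {page_id: os.path.join("out", "run1")}))
    print(new_target(page_id, ".pred.xml", {}))
```

Without a recorded output directory, both versions resolve to the same path, so existing workflows keep writing next to their input; only an explicit output_dir moves the result, which removes the need for per-run extensions.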