facebookresearch / nougat

Implementation of Nougat: Neural Optical Understanding for Academic Documents
https://facebookresearch.github.io/nougat/
MIT License

Option to run with float32 precision on CPU #108

Closed: Vidminas closed this issue 12 months ago.

Vidminas commented 1 year ago

I borrowed code from `predict.py` to write a LangChain PDF loader that uses Nougat.

Full `nougat_loader.py`:

```python3
import logging
import re
from functools import partial
from typing import Iterator, Optional

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from langchain.docstore.document import Document
from langchain.document_loaders.pdf import BasePDFLoader

logger = logging.getLogger(__name__)


class NougatPDFLoader(BasePDFLoader):
    """Load `PDF` files using Nougat (https://facebookresearch.github.io/nougat/)."""

    def __init__(
        self,
        file_path: str,
        *,
        num_workers: Optional[int] = 0,
        headers: Optional[dict] = None,
    ) -> None:
        """Initialize with file path."""
        super().__init__(file_path, headers=headers)
        self.num_workers = num_workers

        try:
            from nougat import NougatModel
            from nougat.utils.checkpoint import get_checkpoint
            from nougat.utils.device import move_to_device
        except ImportError:
            raise ImportError(
                "`nougat` package not found, please install it with "
                "`pip install nougat-ocr`"
            )

        checkpoint = get_checkpoint("nougat", download=True)
        self.model = NougatModel.from_pretrained(checkpoint)

        # Batch-size heuristics copied from predict.py.
        if torch.cuda.is_available():
            self.batch_size = int(
                torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1000 * 0.3
            )
            if self.batch_size == 0:
                self.batch_size = 1
                logger.warning("GPU VRAM is too small. Computing on CPU.")
        elif torch.backends.mps.is_available():
            self.batch_size = 4
        else:
            self.batch_size = 1
            logger.warning("No GPU found. Conversion on CPU is very slow.")

        # By default this converts the weights to torch.bfloat16 (see below).
        self.model = move_to_device(self.model)
        self.model.eval()

    def load(self) -> list[Document]:
        """Eagerly load the content."""
        return list(self.lazy_load())

    def lazy_load(self) -> Iterator[Document]:
        """Lazily load documents, one per successfully converted page."""
        import pypdf
        from nougat.postprocessing import markdown_compatible
        from nougat.utils.dataset import LazyDataset

        try:
            dataset = LazyDataset(
                pdf=self.file_path,
                prepare=partial(self.model.encoder.prepare_input, random_padding=False),
            )
        except pypdf.errors.PdfStreamError:
            logger.info(f"Could not load file {self.file_path}.")
            return

        dataloader = DataLoader(
            dataset,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=LazyDataset.ignore_none_collate,
        )

        for batch_idx, (sample, is_last_page) in enumerate(
            tqdm(
                dataloader,
                desc=f"Processing file {dataset.name} with {dataset.size} pages",
                ncols=80,
                position=0,
                leave=True,
            )
        ):
            model_output = self.model.inference(image_tensors=sample, early_stopping=True)

            # Check whether the model output is faulty.
            for i, output in enumerate(model_output["predictions"]):
                # enumerate() counts batches, so derive the 0-based page index explicitly.
                page_num = batch_idx * self.batch_size + i
                if output.strip() == "[MISSING_PAGE_POST]":
                    # Uncaught repetitions: most likely an empty page.
                    logger.warning(f"[MISSING_PAGE_EMPTY:{page_num + 1}]")
                elif model_output["repeats"][i] is not None:
                    if model_output["repeats"][i] > 0:
                        # The output is most likely incomplete and was truncated.
                        logger.warning(f"Skipping page {page_num + 1} due to repetitions.")
                    else:
                        # The page is too different from the training domain
                        # (this can happen e.g. for cover pages).
                        logger.warning(f"[MISSING_PAGE_EMPTY:{page_num + 1}]")
                else:
                    output = markdown_compatible(output)
                    output = re.sub(r"\n{3,}", "\n\n", output).strip()
                    metadata = {"source": self.file_path, "page": page_num}
                    yield Document(page_content=output, metadata=metadata)
```
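The loader contains two small pieces of arithmetic worth checking in isolation: the CUDA batch-size heuristic (roughly 0.3 pages per GB of VRAM, with a minimum of 1) and the fact that `enumerate` over a `DataLoader` counts batches, not pages, so a per-page number has to be derived when `batch_size > 1`. A minimal pure-Python sketch; `batch_size_from_vram` and `page_indices` are hypothetical helpers, and the 0.3 factor is copied from the loader.

```python
from typing import List


def batch_size_from_vram(total_memory_bytes: int) -> int:
    """Mirror of the loader's CUDA heuristic: ~0.3 pages per GB of VRAM, at least 1."""
    batch_size = int(total_memory_bytes / 1024 / 1024 / 1000 * 0.3)
    return max(batch_size, 1)


def page_indices(batch_idx: int, batch_size: int, pages_in_batch: int) -> List[int]:
    """0-based page indices covered by one batch of a non-shuffled DataLoader."""
    start = batch_idx * batch_size
    return list(range(start, start + pages_in_batch))


if __name__ == "__main__":
    print(batch_size_from_vram(16 * 1024**3))  # 16 GiB card -> 4
    print(page_indices(batch_idx=2, batch_size=4, pages_in_batch=3))  # partial last batch -> [8, 9, 10]
```

With `shuffle=False` the batches arrive in document order, which is what makes `batch_idx * batch_size + i` a valid page index (assuming no pages are dropped by the collate function).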

I'm running this on a CPU-only Windows 10 machine. By default, the model parameters and inputs are converted to `torch.bfloat16` tensors. In my tests, it takes about 8 minutes to process one PDF page (the warning "No GPU found. Conversion on CPU is very slow." is not kidding!).

However, I found that if I skip the conversion to `torch.bfloat16` and leave everything at the default `torch.float32` precision, it runs much faster: about 30 seconds per PDF page. To do this, I commented out the line `self.model = move_to_device(self.model)` in my PDF loader, and the line `image_tensors = image_tensors.to(torch.bfloat16)` (plus its `if` condition) in the `inference` method of `NougatModel` in `nougat\model.py`:

if self.device.type != "mps":
    image_tensors = image_tensors.to(torch.bfloat16)
image_tensors = image_tensors.to(self.device)

last_hidden_state = self.encoder(image_tensors)

(This method is called by `model_output = self.model.inference(image_tensors=sample, early_stopping=True)`.)
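Instead of commenting lines out, the input cast could follow whatever dtype the weights were loaded in, so inputs and weights can never disagree. A minimal sketch assuming only plain PyTorch: `cast_inputs_like_model` is a hypothetical helper, and a small `nn.Linear` stands in for the Nougat encoder. The snippet above skips the cast on MPS; a real patch would presumably keep that special case.

```python
import torch
from torch import nn


def cast_inputs_like_model(model: nn.Module, image_tensors: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: move inputs to the device and dtype of the model's weights."""
    param = next(model.parameters())
    return image_tensors.to(device=param.device, dtype=param.dtype)


if __name__ == "__main__":
    x = torch.randn(3, 4)  # float32 by default

    # Weights left at the default float32: the inputs stay float32.
    encoder_fp32 = nn.Linear(4, 2)
    print(encoder_fp32(cast_inputs_like_model(encoder_fp32, x)).dtype)  # torch.float32

    # Weights cast to bfloat16 once: the inputs follow automatically.
    encoder_bf16 = nn.Linear(4, 2).to(torch.bfloat16)
    print(encoder_bf16(cast_inputs_like_model(encoder_bf16, x)).dtype)  # torch.bfloat16
```

The design choice here is that precision is decided in one place (when the weights are loaded), and the inference path simply matches it.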

It might be useful to add a parameter to `model.inference` for selecting the desired precision, and perhaps a CLI flag as well. What do you think?
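A precision parameter plus a CLI flag, as proposed here, could look roughly like the following. This is a sketch under stated assumptions: the `--full-precision` flag name and the `build_parser`/`select_dtype` helpers are hypothetical, not Nougat's interface. The returned dtype would then be applied once with `model.to(dtype)` and `image_tensors.to(dtype)`.

```python
import argparse
from typing import Optional

import torch


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Sketch of a precision switch")
    # Hypothetical flag name, chosen for illustration only.
    parser.add_argument(
        "--full-precision",
        action="store_true",
        help="Keep float32 instead of casting to bfloat16 (may be faster on some CPUs).",
    )
    return parser


def select_dtype(full_precision: bool, device_type: str) -> Optional[torch.dtype]:
    """Dtype to cast weights and inputs to, or None to leave them unchanged."""
    if device_type == "mps":
        return None  # the existing code skips the bfloat16 cast on MPS
    return torch.float32 if full_precision else torch.bfloat16


if __name__ == "__main__":
    args = build_parser().parse_args(["--full-precision"])
    print(select_dtype(args.full_precision, device_type="cpu"))  # torch.float32
```

Keeping bfloat16 as the default preserves the current GPU behaviour, while the flag opts into float32 for CPUs where it turns out to be faster.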

Vidminas commented 1 year ago

To back this up, I ran a full test on the same PDF, from https://www.workerinfoexchange.org/uber-guidance-document.

This is the result using `torch.bfloat16`:

`100%|███████████████████████████████████████████████████████████████████████| 57/57 [4:38:22<00:00, 293.03s/it]`

And this is the result without moving the weights or inputs (default `torch.float32`):

`100%|███████████████████████████████████████████████████████████████████████| 57/57 [22:39<00:00, 23.85s/it]`

That is more than a 12x speed-up (293.03 s/it vs. 23.85 s/it).
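The 12x figure can be cross-checked from the two progress bars, both from the total elapsed times and from the per-iteration rates. A small pure-Python sketch; `tqdm_elapsed_to_seconds` is a hypothetical helper that assumes tqdm's `H:MM:SS` / `MM:SS` elapsed format.

```python
def tqdm_elapsed_to_seconds(elapsed: str) -> int:
    """Convert a tqdm elapsed field such as '4:38:22' or '22:39' to seconds."""
    seconds = 0
    for part in elapsed.split(":"):
        seconds = seconds * 60 + int(part)
    return seconds


if __name__ == "__main__":
    bf16_total = tqdm_elapsed_to_seconds("4:38:22")  # 16702 s
    fp32_total = tqdm_elapsed_to_seconds("22:39")  # 1359 s
    print(f"total-time speed-up: {bf16_total / fp32_total:.2f}x")  # 12.29x
    print(f"per-page speed-up:   {293.03 / 23.85:.2f}x")  # 12.29x
```

Both ways of measuring agree, since all 57 pages were processed in each run.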

System information:

OS Name: Microsoft Windows 10 Pro
Version: 10.0.19045 Build 19045
Processor: Intel(R) Core(TM) i7-7500U CPU @ 2.70GHz, 2904 Mhz, 2 Core(s), 4 Logical Processor(s)
Total Physical Memory: 15.9 GB
Python: Python 3.11.4 | packaged by Anaconda, Inc. | (main, Jul  5 2023, 13:38:37) [MSC v.1916 64 bit (AMD64)] on win32

I don't know whether this is a general rule for CPU inference or just an edge case. Either way, being able to choose the precision would let users handle it.

lukas-blecher commented 1 year ago

Oh wow. Thanks! Let me have a look.

lukas-blecher commented 1 year ago

That didn't work for me. On an 11th Gen Intel(R) Core(TM) i5-1135G7 @ 2.40GHz it actually got slower, by about 16 s/it (from 43.58 to 59.26 s/it).
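Since the two machines disagree, the practical answer is to measure on the target CPU. A rough micro-benchmark sketch, assuming plain PyTorch: it times a square matrix multiply in each dtype, which is only a proxy for Nougat's encoder and decoder. `seconds_per_matmul` is a hypothetical helper, and its numbers will depend on the CPU and the PyTorch build.

```python
import time

import torch


def seconds_per_matmul(dtype: torch.dtype, size: int = 256, repeats: int = 5) -> float:
    """Average wall-clock time of one size x size matmul on CPU in the given dtype."""
    a = torch.randn(size, size).to(dtype)
    b = torch.randn(size, size).to(dtype)
    _ = a @ b  # warm-up run, excluded from the timing
    start = time.perf_counter()
    for _ in range(repeats):
        _ = a @ b
    return (time.perf_counter() - start) / repeats


if __name__ == "__main__":
    for dtype in (torch.float32, torch.bfloat16):
        print(f"{dtype}: {seconds_per_matmul(dtype) * 1e3:.2f} ms per matmul")
```

If bfloat16 comes out clearly slower than float32 here, full precision is likely worth trying for the real conversion too; the end-to-end timing per page remains the number that matters.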

lukas-blecher commented 12 months ago

I've added the option to use full precision.