kermitt2 / grobid

A machine learning software for extracting information from scholarly documents
https://grobid.readthedocs.io

can I use GPU to train segmentation model? #964

majiajun0 commented 1 year ago

I run GROBID on Ubuntu with an NVIDIA 3080 Ti GPU, not in Docker. By default, GROBID trains the segmentation model on the CPU. I want to train the segmentation model on the GPU, but I don't know how to configure GROBID for that. I am also not sure whether segmentation training supports the GPU at all, since it uses a CRF algorithm. The code is a little difficult for me to read, so I would be grateful for an explanation!

kermitt2 commented 1 year ago

Hi @majiajun0 !

Yes you can :)

Good news, there is no dirty code to touch.

You just need:

1) Modify the config file (grobid-home/config/grobid.yaml) so that a deep learning architecture is used instead of "wapiti" for the segmentation model, for instance:

- name: "segmentation"
      #engine: "wapiti"
      engine: "delft"
      wapiti:
        # wapiti training parameters, they will be used at training time only
        epsilon: 0.0000001
        window: 50
        nbMaxIterations: 2000
      delft:
        # deep learning parameters
        architecture: "BidLSTM_CRF_FEATURES"
        useELMo: false
        runtime:
          # parameters used at runtime/prediction
          max_sequence_length: 3000
          batch_size: 1
        training:
          # parameters used for training
          max_sequence_length: 3000
          batch_size: 10

(You can adapt the training parameters to your GPU, e.g. lower the training batch_size if you run out of GPU memory.)

2) You can then train the segmentation model with the new parameters:

./gradlew train_segmentation

You should see the TensorFlow/DeLFT model information on the console and not the CRF epoch info.
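If you want to double-check that the GPU is actually picked up (and not a silent CPU fallback), you can query TensorFlow, which DeLFT runs on, from the Python environment GROBID uses for DeLFT; a minimal sketch, assuming a TensorFlow 2.x install (you can also simply watch nvidia-smi while the gradle task runs):

```python
# Minimal sanity check that TensorFlow (the backend used by DeLFT) sees the GPU.
# Run it from the same Python environment / virtualenv that GROBID uses for DeLFT.
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
print(gpus)  # one entry expected for the 3080 Ti; an empty list means CPU-only
```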

I did a test today: with BidLSTM_CRF_FEATURES, the training took me around 6 hours on a 1080 Ti. That is still quite a lot of time, but ~3 times faster than CRF training.

But it leads to poor results at this stage, around -2 points on the F1-score on average for the fields of the PMC eval set. I suppose it's because the features were designed for CRF and there is not enough "full instance" training data for an RNN. But the runtime is still good.

Note 1: you might want to increase the max sequence length if you think your documents might have more than 3000 lines (warning: these are "lines" according to the PDF stream, usually more than the actual visual lines); see the sketch after these notes for a rough way to check your training data.

Note 2: Be sure to use a model with a FEATURES channel, because the segmentation model works at line level with layout information: most of the information is in the categorical features, with very limited text content (just the text prefix of each line).

Note 3: BERT-style models are expected to work worse because of the 512-token input size limit. They will be much slower at runtime but will train faster.
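Regarding Note 1, a rough way to gauge whether a max_sequence_length of 3000 is enough is to count the rows per document in the segmentation training corpus, since each non-empty row of a raw feature file corresponds to one PDF-stream line. This is only a sketch: the default corpus path below is an assumption and may differ in your setup.

```python
# Rough distribution of document lengths (in PDF-stream lines) in the
# segmentation training corpus. Assumptions: each non-empty row of a "raw"
# feature file is one PDF-stream line, and the corpus sits at the default
# training data location; adjust the path to your setup.
from pathlib import Path

raw_dir = Path("grobid-trainer/resources/dataset/segmentation/corpus/raw")
lengths = sorted(
    sum(1 for line in path.open(encoding="utf-8") if line.strip())
    for path in raw_dir.iterdir() if path.is_file()
)
if lengths:
    print(f"documents: {len(lengths)}, "
          f"median lines: {lengths[len(lengths) // 2]}, "
          f"max lines: {lengths[-1]}")
```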

majiajun0 commented 1 year ago

@kermitt2 Thanks very much!