courao / ocr.pytorch

A pure pytorch implemented ocr project including text detection and recognition
MIT License
584 stars, 133 forks

I made a pytorch-lightning implementation of your CTPN #47

Closed mathemusician closed 2 years ago

mathemusician commented 3 years ago

Hi! I've made a pytorch-lightning implementation of CTPN, based mainly on your code. Pytorch-lightning has many nice features, such as training on TPUs or multiple GPUs by changing one line of code, 16-bit precision, CPU support (nice for testing), and an automatic learning rate finder. Would you be open to a pull request? Link to fork here! I'm in the process of converting your CRNN to pytorch-lightning as well.
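
To make the "one line of code" point concrete, here is a rough sketch of how the hardware and precision options could differ between runs. It assumes the 1.x-era Trainer arguments (`gpus`, `tpu_cores`, `precision`, `auto_lr_find`), the same API generation as the `gpus=1` call in this issue; later releases renamed several of them. The `base` settings, the `presets` dict, and the `changed_keys` helper are hypothetical names used only for illustration, and the snippet builds plain dicts rather than calling pytorch-lightning.

```python
# Sketch only: keyword arguments in the style of the pytorch-lightning 1.x
# Trainer. Each preset would be passed as pl.Trainer(**preset); nothing here
# imports or calls pytorch-lightning itself.

base = {"max_epochs": 10, "log_every_n_steps": 1}  # hypothetical shared settings

presets = {
    "cpu": {**base, "gpus": 0},                        # CPU only, handy for testing
    "single_gpu": {**base, "gpus": 1},
    "multi_gpu": {**base, "gpus": 2},                  # the one-line change
    "tpu": {**base, "tpu_cores": 8},
    "half_precision": {**base, "gpus": 1, "precision": 16},
    "lr_finder": {**base, "gpus": 1, "auto_lr_find": True},
}


def changed_keys(a, b):
    """Return the sorted keys whose values differ between two presets."""
    return sorted(k for k in set(a) | set(b) if a.get(k) != b.get(k))


# Going from one GPU to two touches a single argument.
print(changed_keys(presets["single_gpu"], presets["multi_gpu"]))
```

In the 1.x API, `auto_lr_find=True` only takes effect when `trainer.tune(model)` is called before `trainer.fit`.
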

Here's the simplified training loop:

import pytorch_lightning as pl

# ICDARDataModule, CTPN_Model, config, the three callbacks, and max_epochs
# are assumed to be defined elsewhere in the fork.

# Data module wrapping the ICDAR 2017 MLT images and ground-truth files.
datamodule = ICDARDataModule(
    config.icdar17_mlt_img_dir,
    config.icdar17_mlt_gt_dir,
    batch_size=1,
    num_workers=config.num_workers,
    shuffle=True,
)

len_train_dataset = len(datamodule.train_data)

model = CTPN_Model()

trainer = pl.Trainer(
    gpus=1,  # number of GPUs; use 0 to train on CPU
    max_epochs=max_epochs,
    log_every_n_steps=1,
    callbacks=[
        LoadCheckpoint(config.pretrained_weights),
        InitializeWeights(),
        LossAndCheckpointCallback(config, len_train_dataset),
    ],
)

trainer.fit(model, datamodule)
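
For readers unfamiliar with the `callbacks=[...]` list above, the following toy sketch shows the general dispatch pattern: the trainer owns the loop and calls named hooks on each callback in order. It is not pytorch-lightning's Trainer and not the fork's code. Only the hook names (`on_fit_start`, `on_train_batch_end`, `on_train_epoch_end`) mirror Lightning's Callback API, and the real hooks take more arguments (for example the module being trained). `ToyTrainer`, `RecordStart`, and `AverageLoss` are hypothetical, and using the dataset length to average the loss is only a guess at why `len_train_dataset` is passed to `LossAndCheckpointCallback`.

```python
# Toy stand-in for a trainer that dispatches to callbacks (illustration only).

class ToyCallback:
    """Base class with no-op hooks, so subclasses override only what they need."""

    def on_fit_start(self, trainer):
        pass

    def on_train_batch_end(self, trainer, loss):
        pass

    def on_train_epoch_end(self, trainer):
        pass


class RecordStart(ToyCallback):
    """Hypothetical analogue of a callback that runs once before training."""

    def on_fit_start(self, trainer):
        trainer.events.append("fit_start")


class AverageLoss(ToyCallback):
    """Hypothetical analogue of a loss-tracking callback."""

    def __init__(self, steps_per_epoch):
        self.steps_per_epoch = steps_per_epoch
        self.total = 0.0

    def on_train_batch_end(self, trainer, loss):
        self.total += loss

    def on_train_epoch_end(self, trainer):
        mean = self.total / self.steps_per_epoch
        trainer.events.append(("epoch_mean_loss", mean))
        self.total = 0.0  # reset the accumulator for the next epoch


class ToyTrainer:
    def __init__(self, max_epochs, callbacks):
        self.max_epochs = max_epochs
        self.callbacks = callbacks
        self.events = []

    def fit(self, batch_losses):
        for cb in self.callbacks:
            cb.on_fit_start(self)
        for _ in range(self.max_epochs):
            for loss in batch_losses:  # stands in for one training step per batch
                for cb in self.callbacks:
                    cb.on_train_batch_end(self, loss)
            for cb in self.callbacks:
                cb.on_train_epoch_end(self)


losses = [4.0, 2.0]  # made-up per-batch losses
toy_trainer = ToyTrainer(max_epochs=2, callbacks=[RecordStart(), AverageLoss(len(losses))])
toy_trainer.fit(losses)
print(toy_trainer.events)
```

Keeping checkpoint loading, weight initialization, and logging in separate callbacks leaves the model class focused on the forward pass and the loss.
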
courao commented 3 years ago

Nice job!

mathemusician commented 3 years ago

> Nice job!

Should I make it into my own repository? It's not a "pure pytorch" implementation, but it works on CPU, a single GPU, and multiple GPUs. I've added visualization during training, and I'm planning to add more OCR models. Pull request or no?