MhLiao / DB

A PyTorch implementation of "Real-time Scene Text Detection with Differentiable Binarization".

OOM on validation #320

Open kurbobo opened 2 years ago

kurbobo commented 2 years ago

I can run training on my GPU with batch size = 24, but validation runs out of GPU memory (OOM). Validation doesn't need gradients, only the weights, so it should use less GPU memory than training. I suspect a bug in the order in which weights are moved onto and off the GPU during validation. The OOM only goes away when I set the validation batch size to 12.
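A minimal sketch of what the "no gradients needed" reasoning implies in PyTorch, assuming the validation loop does not already disable autograd (this has not been checked against DB/trainer.py). Running the forward pass inside `torch.no_grad()` stops PyTorch from saving activations for a backward pass, which is where most of the extra memory goes. `run_validation`, `toy_model` and `toy_batches` are hypothetical stand-ins, not part of the DB code.

```python
import torch
import torch.nn as nn


def run_validation(model: nn.Module, batches):
    """Forward every batch without recording an autograd graph (hypothetical helper)."""
    was_training = model.training
    model.eval()  # put BatchNorm/Dropout layers into inference mode
    outputs = []
    with torch.no_grad():  # activations are not kept for backward
        for images in batches:
            outputs.append(model(images))
    model.train(was_training)  # restore whatever mode the model was in
    return outputs


# Toy stand-in for the detector; the real DB model is not used here.
toy_model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
toy_batches = [torch.randn(2, 3, 64, 64) for _ in range(3)]

preds = run_validation(toy_model, toy_batches)
print(preds[0].requires_grad)  # False: no graph was recorded
print(toy_model(toy_batches[0]).requires_grad)  # True: outside no_grad a graph is built
```

In the DB trainer this would amount to wrapping the loop body of validate_step() in `with torch.no_grad():`, if it is not already wrapped.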

ChihabEddine98 commented 2 years ago

Hey @kurbobo! I actually ran into the same issue. I think something in the validation code is causing it. I was able to make it work without reducing the batch size with this code (it might help you):

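import gc

from tqdm import tqdm


class Trainer:
    # ... other Trainer methods from DB/trainer.py stay unchanged ...

    def validate_step(self, data_loader, model, visualize=False):
        raw_metrics = []
        vis_images = dict()
        for i, batch in tqdm(enumerate(data_loader), total=len(data_loader)):
            # Run Python's garbage collector before each batch so unreachable
            # objects from earlier iterations (and any GPU tensors they hold)
            # are released.
            gc.collect()
            pred = model.forward(batch, training=False)
            # Post-process the predicted maps into detections.
            output = self.structure.representer.represent(batch, pred)
            raw_metric, interested = self.structure.measurer.validate_measure(
                batch, output)
            raw_metrics.append(raw_metric)

            if visualize and self.structure.visualizer:
                vis_image = self.structure.visualizer.visualize(
                    batch, output, interested)
                vis_images.update(vis_image)
        # Aggregate the per-batch metrics over the whole validation set.
        metrics = self.structure.measurer.gather_measure(
            raw_metrics, self.logger)
        return metrics, vis_images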

You can replace the validate_step() method of Trainer in DB/trainer.py with this one.

My guess is that a parameter is not passed correctly between validate() and validate_step(), but I didn't investigate further.

sir-timio commented 2 years ago

Hi, did you completely solve the issue with validation? I tried your snippet, but got:

Traceback (most recent call last):
  File "train.py", line 70, in <module>
    main()
  File "train.py", line 67, in main
    trainer.train()
  File "/workspace/DB/trainer.py", line 82, in train
    self.validate(validation_loaders, model, epoch, self.steps)
  File "/workspace/DB/trainer.py", line 149, in validate
    loader, model, True)
  File "/workspace/DB/trainer.py", line 171, in validate_step
    pred = model.forward(batch, training=False)
  File "/workspace/DB/structure/model.py", line 63, in forward
    loss_with_metrics = self.criterion(pred, batch)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/workspace/DB/decoders/seg_detector_loss.py", line 194, in forward
    bce_loss = self.bce_loss(pred['binary'], batch['gt'], batch['mask'])
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/workspace/DB/decoders/balance_cross_entropy_loss.py", line 40, in forward
    positive = (gt[:,0,:,:] * mask).byte()
TypeError: list indices must be integers or slices, not tuple
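The traceback shows the loss criterion being called from model.forward during validation, and the last frame fails on `gt[:,0,:,:]`. That TypeError is what Python raises whenever a plain list is indexed with a tuple of slices, which suggests `batch['gt']` reached the loss as a Python list rather than a tensor; the thread does not establish why. The sketch below only reproduces the error; `take_first_channel` is a hypothetical helper mirroring that line, and NumPy is used for the working case.

```python
import numpy as np


def take_first_channel(gt):
    """Hypothetical helper mirroring the `gt[:, 0, :, :]` indexing in the loss."""
    return gt[:, 0, :, :]


# A 1x1x2x2 ground-truth map stored as nested Python lists.
gt_list = [[[[0.0, 1.0], [1.0, 0.0]]]]

try:
    take_first_channel(gt_list)
except TypeError as err:
    message = str(err)
    print(message)  # list indices must be integers or slices, not tuple

# The same data as an array supports multi-dimensional slicing.
gt_array = np.asarray(gt_list)
print(take_first_channel(gt_array).shape)  # (1, 2, 2)
```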

windspirit95 commented 2 years ago


From reading the config files, the training input size is 640 x 640, while the validation input size is 736 x 736 (td500) or 800 x 800 (totaltext). Each validation image therefore has more pixels than a training image, which could explain why a training batch size of 24 fits in your case but validation needs a smaller batch size to avoid OOM :)
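A rough worked version of this argument, under the simplifying assumption that activation memory scales linearly with the number of input pixels. It ignores fixed costs (weights, optimizer state, CUDA context) and whether validation keeps activations for backward. The sizes come from the comment above and the batch size of 24 from the original report; the helper names are hypothetical.

```python
TRAIN_SIDE = 640   # training input is 640 x 640
TRAIN_BATCH = 24   # training batch size that fits on the reporter's GPU
VAL_SIDES = {"td500": 736, "totaltext": 800}  # validation input sizes


def pixel_ratio(val_side: int, train_side: int = TRAIN_SIDE) -> float:
    """How many times more pixels one validation image has than a training image."""
    return (val_side * val_side) / (train_side * train_side)


def max_batch_same_pixels(val_side: int, train_batch: int = TRAIN_BATCH) -> int:
    """Largest validation batch whose total pixel count does not exceed the training batch's."""
    return train_batch * TRAIN_SIDE * TRAIN_SIDE // (val_side * val_side)


for name, side in VAL_SIDES.items():
    print(f"{name}: {pixel_ratio(side):.4f}x pixels, batch <= {max_batch_same_pixels(side)}")
# td500: 1.3225x pixels, batch <= 18
# totaltext: 1.5625x pixels, batch <= 15
```

Under this crude model a validation batch of 24 at 800 x 800 holds as many pixels as 37.5 training images, while a batch of 12 stays below the bound, which is consistent with the batch size that worked in the original report.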