tkuanlun350 / 3DUnet-Tensorflow-Brats18

3D Unet biomedical segmentation model powered by tensorpack with fast io speed
202 stars 68 forks

out of memory after several online eval iterations #8

Open huangmozhilv opened 5 years ago

huangmozhilv commented 5 years ago

Training runs fine and consumes about 10 GB of CPU memory. However, memory usage grows quickly once online evaluation (triggered by EvalCallback) starts, and reaches 60 GB after several eval iterations. Has anyone else observed the same problem? How did you solve it?

tkuanlun350 commented 5 years ago

I haven't run into this problem before. 60 GB of memory usage sounds impossible to me.

Can you try offline evaluation to see whether the problem still exists? For example: python3 train.py --load /path/to/ckpt/ --evaluate ...

huangmozhilv commented 5 years ago

I changed to offline_evaluate() by using

model_path='train_log/unet3d/model-5'
pred = OfflinePredictor(PredictConfig(
                    model=get_model(modelType="inference"),
                    session_init=get_model_loader(model_path),
                    input_names=['image'],
                    output_names=get_model_output_names()))

Something goes wrong. The log shows:

[1119 19:40:30 @sessinit.py:117] Restoring checkpoint from train_log/unet3d/model-5 ...
INFO:tensorflow:Restoring parameters from train_log/unet3d/model-5
[1119 19:40:27 @sessinit.py:90] WRN The following variables are in the checkpoint, but not found in the graph: global_step:0, learning_rate:0

tkuanlun350 commented 5 years ago

The warning is expected: the variables global_step:0 and learning_rate:0 are only used in training mode. Is there any other exception?

huangmozhilv commented 5 years ago

I see. Thank you. Finally, I found a solution to avoid OOM:

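import threading
# Pass the method itself (no parentheses); target=self._eval() would run the
# evaluation in the calling thread and hand its return value to Thread.
thread = threading.Thread(target=self._eval, name='self._eval')
thread.start()
thread.join() # wait for the thread to finish and close it to save memory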
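
A note on the thread snippet, as a minimal standalone sketch (the `work` function and the `calls` list are hypothetical stand-ins, not part of this repo): `threading.Thread(target=f())` calls `f` immediately in the current thread and gives its return value to `Thread`, whereas `threading.Thread(target=f)` hands over the function so that it runs inside the new thread.

```python
import threading

calls = []  # records the name of the thread each call ran in


def work():
    # Hypothetical stand-in for self._eval().
    calls.append(threading.current_thread().name)


# target=work(): work() is executed right here, in the calling thread, and
# its return value (None) becomes the target, so the new thread does nothing.
t1 = threading.Thread(target=work(), name='eval-called-early')
t1.start()
t1.join()

# target=work: the function object is passed and runs inside the new thread.
t2 = threading.Thread(target=work, name='eval-in-thread')
t2.start()
t2.join()

print(calls)
```

The first recorded name is that of the calling thread, the second is 'eval-in-thread'. Whether either variant changes memory usage in tensorpack is not something this sketch can show.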

tkuanlun350 commented 5 years ago

Nice! It would be great if you submitted a pull request. Other people may be facing the same problem.

huangmozhilv commented 5 years ago

Hi @tkuanlun350, I'm now adapting your code to the LiTS (liver segmentation) challenge. The 3D CT volumes are much larger than in BRATS, i.e. 512x512xn (n = 100~1000). In this case I can't do online prediction with EvalCallback(), even with the threading block above applied. The training process takes about 30 GB of memory. When EvalCallback() runs after several epochs (I used 10 epochs, each epoch = 250 steps), memory soon grows to more than 60 GB and triggers OOM. I checked the processes with htop, and a few GB of memory are released when I kill the PIDs marked with D status during EvalCallback(). This looks like a deadlock, but I can't tell where in the code a deadlock could happen. Could you give me any clues toward a solution? Thank you very much.

My major revisions to your code:

  1. In get_eval_dataflow(), I changed MapDataComponent() to MapData() to adapt it to my inputs.
  2. In class EvalCallback(Callback), I wrapped the call to _eval() in a thread inside _trigger_epoch().

Below is my revised code:


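# Standard-library and NumPy imports used below. The tensorpack classes and the
# repo helpers (config, data_loader, logger, ...) are imported as in the project.
import csv
import gc
import os
import threading

import numpy as np

def get_eval_dataflow(images_path, labels_path):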
    # #if config.CROSS_VALIDATION:
    # imgs = SEG_loader.load_from_file(config.BASEDIR, config.VAL_MODE)
    # # no filter for training
    files = data_loader.load_files(images_path, labels_path)
    files = list(files)

    # Yields [files[i]['id'], files[i]['image_data'], files[i]['gt'], files[i]['preprocessed']],
    # i.e. each dict is flattened into a list in the given key order.
    ds = DataFromListOfDict(files, ['id', 'image_data', 'gt', 'preprocessed'])
    ds.reset_state()
    def eval_preprocess(data):
        if config.NO_CACHE:
            gt, im = data[2], data[1]
            volume_list, label, weight, original_shape, bbox = crop_brain_region(im, gt)
            batch = sampler3d_whole(volume_list, label, weight, original_shape, bbox, gt)
            # logger.info('volume_list[0].shape:{}, original_shape:{}, batch_images shape:{}, batch_original shape:{}, batch_bbox shape:{}'.format(volume_list[0].shape, original_shape, batch['images'].shape, str(batch['original_shape']), str(batch['bbox'])))
            for key in batch.keys():
                if isinstance(batch[key], np.ndarray):
                    batch[key] = np.ascontiguousarray(batch[key])
        else:
            volume_list, label, weight, original_shape, bbox = data[3]
            # gt is only bound in the NO_CACHE branch above, so read it from data[2] here
            batch = sampler3d_whole(volume_list, label, weight, original_shape, bbox, data[2])

        del volume_list
        del label
        del weight
        gc.collect()
        return [data[0], data[1], data[2], batch]

    ds = MapData(ds, eval_preprocess) # should return yield list to pass to PrefetchDataZMQ()?
    ds = PrefetchDataZMQ(ds, 1)

    del files
    return ds

class EvalCallback(Callback):
    def __init__(self, images_path, labels_path):
        self.chief_only = True # Only run this callback on chief training process.
        self.images_path = images_path
        self.labels_path = labels_path

    def _setup_graph(self):
        # ipdb.set_trace()
        self.pred = self.trainer.get_predictor(
            ['image'], get_model_output_names())
        self.df = get_eval_dataflow(self.images_path, self.labels_path)

    def _eval(self):
        logger.info('Evaluate after epoch {}'.format(self.epoch_num))
        scores = eval_brats(self.df, lambda img: segment_one_image(img, [self.pred], is_online=True), outdir=config.OUTDIR, epoch_num=self.epoch_num)
        # Append the scores to a CSV log; the with-block closes the file even on error.
        with open(os.path.join(os.getcwd(), 'eval_res.csv'), mode='a+', newline='') as fo:
            wo = csv.writer(fo, delimiter=',')
            for k, v in scores.items():
                self.trainer.monitors.put_scalar(k, v)
                wo.writerow([config.TASK, self.epoch_num, config.STEP_PER_EPOCH, k, v, tinies.datestr()])

    def _trigger_epoch(self):
        if self.epoch_num > 0 and self.epoch_num % config.EVAL_EPOCH == 0:
            # self._eval()
            # Pass the bound method itself; target=self._eval() would call it here first.
            thread = threading.Thread(target=self._eval, name='self._eval')
            thread.start()
            thread.join() # wait for the thread to finish and close it to save memory

mini-Shark commented 5 years ago

@huangmozhilv Hi, I'm working on online prediction too. Have you solved this problem? Could you please give me some advice on it? Thanks.

huangmozhilv commented 5 years ago

@mini-Shark Yes. I found that the cause lies in QueueInput(get_train_dataflow()). get_train_dataflow(), which covers everything from loading data from disk to preprocessing it for the queue, keeps running all the time. During online evaluation, once the queue of preprocessed training data is full, newly preprocessed training data keeps accumulating in memory, which leads to the out-of-memory problem. Since we built our pipeline from scratch in PyTorch and only borrowed some code from this repo, we solved the problem by writing our own QueueInput()-like function with Python's built-in multiprocessing module.
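
The replacement function itself is not shown in this thread. As a rough sketch of the general idea only, assuming the goal is a bounded queue whose producer blocks instead of piling up data while the consumer is busy: the names `producer`, `consume_slowly`, and `sizes` are hypothetical, and a thread with `queue.Queue` is used for brevity (`multiprocessing.Queue` accepts the same `maxsize` argument).

```python
import queue
import threading
import time


def producer(q, n_items, sizes):
    """Push preprocessed items; put() blocks while the queue is full."""
    for i in range(n_items):
        q.put(i)                 # waits here until a slot is free
        sizes.append(q.qsize())  # record the queue size after each put
    q.put(None)                  # sentinel: no more data


def consume_slowly(q, delay):
    """Simulate a trainer that is busy (e.g. evaluating) between reads."""
    items = []
    while True:
        item = q.get()
        if item is None:
            return items
        items.append(item)
        time.sleep(delay)


q = queue.Queue(maxsize=2)  # the bound of 2 is an illustrative choice
sizes = []
t = threading.Thread(target=producer, args=(q, 10, sizes))
t.start()
received = consume_slowly(q, delay=0.01)
t.join()
print(received, max(sizes))
```

Because `put()` blocks once two items are waiting, the producer never holds more than the bound in the queue, however slow the consumer is.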

mini-Shark commented 5 years ago

@huangmozhilv Saaad... Is there any way to avoid this situation? Maybe this is a stupid question, but I don't have time to rewrite the whole pipeline :(

huangmozhilv commented 5 years ago

I have no idea how to do it with tensorpack.

mini-Shark commented 5 years ago

@huangmozhilv Anyway, thanks for your reply

tkuanlun350 commented 5 years ago

@mini-Shark Is your problem that you cannot run evaluation and training at the same time because of a memory bottleneck? You can try setting config.NO_CACHE = True to load data online. When config.NO_CACHE = False, all images are loaded and preprocessed up front to accelerate training, but this consumes a lot more memory.

If the problem is not solved, I think we can open a new issue for the problem for better discussion.

huangmozhilv commented 5 years ago

@tkuanlun350 It's a different problem. I think @mini-Shark should also set 'NO_CACHE = True'. The problem is that if online evaluation takes a long time (e.g. when half of the BRATS dataset is used for online evaluation), the training queue fills up, and preprocessed data from get_train_dataflow() is temporarily stored in memory instead of in the queue.

tkuanlun350 commented 5 years ago

@huangmozhilv Thanks! I will look into the tensorpack source code to figure out a workaround. The ugly solution is to discard the queue input and just use feed_dict.

huangmozhilv commented 5 years ago

@tkuanlun350 Thank you.

mini-Shark commented 5 years ago

@huangmozhilv @tkuanlun350 Thanks for your help, both of you. I may have found a trade-off: pass an additional parameter to 'PrefetchDataZMQ' when defining 'get_train_dataflow()'. 'PrefetchDataZMQ(ds, nr_proc=1, hwm=50)' has a parameter 'hwm' with a default of 50, which controls the queue size of the dataflow. I changed it to 'hwm=2'. I also modified 'get_eval_dataflow()' so that it doesn't load all validation data at once.

I'm not sure this will work properly, but it no longer raises OOM (I have 64 GB of memory).
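
To get a feel for why the buffer size matters, here is a back-of-the-envelope sketch. It assumes each buffered datapoint is a single float32 volume of 512x512x500 voxels (a size in the LiTS range mentioned earlier in this thread) and that exactly `hwm` datapoints are buffered. Both are simplifying assumptions: real datapoints also carry labels and weights, and the number of items actually buffered may differ. The helper `buffered_gib` is hypothetical.

```python
def buffered_gib(hwm, shape, bytes_per_voxel=4):
    """Approximate memory held by `hwm` buffered volumes, in GiB."""
    voxels = 1
    for dim in shape:
        voxels *= dim
    return hwm * voxels * bytes_per_voxel / 2**30


shape = (512, 512, 500)  # assumed volume size, float32 voxels
for hwm in (50, 2):
    print('hwm={}: ~{:.1f} GiB'.format(hwm, buffered_gib(hwm, shape)))
```

Under these assumptions a full buffer takes roughly 24.4 GiB at hwm=50 and about 1.0 GiB at hwm=2, so a small buffer trades some prefetch throughput for a much lower memory ceiling.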