microsoft / CNTK

Microsoft Cognitive Toolkit (CNTK), an open source deep-learning toolkit
https://docs.microsoft.com/cognitive-toolkit/
Other
17.49k stars 4.29k forks source link

memory leakage cause gpu out of memory #2677

Open mohamad-hasan-sohan-ajini opened 6 years ago

mohamad-hasan-sohan-ajini commented 6 years ago

Hi

I want to train a RNNLM. My vocabulary size is 48603 (cutoff=100) or 72294 (cutoff=50). Data is stored in CTF format which is suitable for sparse data. Training RNN with vocabulary size of 48603 goes fine, but with bigger vocabulary size, say 72294, after 135 epochs, the training process stops due to GPU out of memory!

The training code is as follows:

def train_network():
    """Construct the RNN language model and train it over the CTF corpus,
    writing a checkpoint to disk after every epoch.

    Relies on module-level configuration: ``vocab_size``, ``epochs``,
    ``mb_per_epoch``, ``mb_size``, ``cwd``, and the helpers
    ``create_reader`` / ``network`` defined elsewhere in this script.
    """
    # Sequence-valued input/output nodes of the computation graph.
    feature = C.sequence.input_variable(vocab_size, name='feature')
    label = C.sequence.input_variable(vocab_size, name='label')

    # Minibatch source over the sparse CTF file, mapped onto the inputs.
    reader = create_reader('text_co50.ctf', vocab_size)
    stream_map = {
        feature: reader.streams.features,
        label: reader.streams.labels,
    }

    # Model plus training criterion and evaluation metric.
    model = network(feature)
    loss = C.cross_entropy_with_softmax(model, label, name='loss')
    metric = C.classification_error(model, label, name='error')

    # TensorBoard progress logging every 10 minibatches.
    tb_writer = C.logging.TensorBoardProgressWriter(freq=10, log_dir='log_my_lm_co50', model=model)

    # Adam learner: per-sample learning rate with constant momentum.
    lr = C.learning_rate_schedule(.001, C.UnitType.sample)
    mm = C.momentum_schedule(.9)
    learner = C.adam(model.parameters, lr, mm)

    trainer = C.train.Trainer(model, (loss, metric), learner, [tb_writer])

    # Main loop: one pass of mb_per_epoch minibatches per epoch, then a checkpoint.
    for epoch in range(epochs):
        print('epoch {}...'.format(epoch))
        started = time.time()
        for _ in range(mb_per_epoch):
            minibatch = reader.next_minibatch(mb_size, input_map=stream_map)
            trainer.train_minibatch(minibatch)
        print('epoch {} takes {} seconds long.\nsaving model.\n'.format(epoch, time.time() - started))
        trainer.save_checkpoint(os.path.join(cwd, 'Models_co50', 'model_{}.cntk'.format(epoch)))

and the error:

epoch 135... epoch 135 takes 199.74186396598816 seconds long. saving model. epoch 136... CUDA failure 2: out of memory ; GPU=0 ; hostname=aj-pc ; expr=cudaMalloc((void) &deviceBufferPtr, sizeof(AllocatedElemType) AsMultipleOf(numElements, 2)) Traceback (most recent call last): File "train_rnnlm.py", line 73, in train_network() File "train_rnnlm.py", line 68, in train_network trainer.train_minibatch(mb) File "/home/aj/anaconda3/lib/python3.5/site-packages/cntk/train/trainer.py", line 181, in train_minibatch arguments, device) File "/home/aj/anaconda3/lib/python3.5/site-packages/cntk/cntk_py.py", line 2850, in train_minibatch_overload_for_minibatchdata return _cntk_py.Trainer_train_minibatch_overload_for_minibatchdata(self, args) RuntimeError: CUDA failure 2: out of memory ; GPU=0 ; hostname=aj-pc ; expr=cudaMalloc((void) &deviceBufferPtr, sizeof(AllocatedElemType) AsMultipleOf(numElements, 2)) [CALL STACK] [0x7f3da866aebc] + 0x5a3ebc [0x7f3da3c5cd43] + 0xb4fd43 [0x7f3da3c8f45d] float Microsoft::MSR::CNTK::TracingGPUMemoryAllocator:: Allocate (int, unsigned long, unsigned long) + 0x4d [0x7f3da3c8f75e] Microsoft::MSR::CNTK::GPUMatrix:: Resize (unsigned long, unsigned long, bool) + 0xee [0x7f3da3be7a25] Microsoft::MSR::CNTK::Matrix:: Resize (unsigned long, unsigned long, unsigned long, bool) + 0xb5 [0x7f3da899c001] Microsoft::MSR::CNTK::ComputationNode:: UpdateDataSize (Microsoft::MSR::CNTK::Matrix&) + 0x31 [0x7f3da899cbe3] Microsoft::MSR::CNTK::ComputationNode:: LazyZeroGradient (Microsoft::MSR::CNTK::ComputationNodeBase const*) + 0x143 [0x7f3da899e323] Microsoft::MSR::CNTK::ComputationNode:: Backprop (Microsoft::MSR::CNTK::FrameRange const&, bool, bool) + 0xf3 [0x7f3da8a28811] Microsoft::MSR::CNTK::ComputationNetwork::PARTraversalFlowControlNode:: Backprop (Microsoft::MSR::CNTK::FrameRange const&, bool, bool) + 0xb1 [0x7f3da8852cf3] CNTK::CompositeFunction:: Backward (std::shared_ptr const&, 
std::unordered_map<CNTK::Variable,std::shared_ptr,std::hash,std::equal_to,std::allocator<std::pair<CNTK::Variable const,std::shared_ptr>>> const&, std::unordered_map<CNTK::Variable,std::shared_ptr,std::hash,std::equal_to,std::allocator<std::pair<CNTK::Variable const,std::shared_ptr>>>&) + 0x293 [0x7f3da88d5cc7] CNTK::Trainer:: ExecuteForwardBackward (std::unordered_map<CNTK::Variable,std::shared_ptr,std::hash,std::equal_to,std::allocator<std::pair<CNTK::Variable const,std::shared_ptr>>> const&, std::unordered_map<CNTK::Variable,std::shared_ptr,std::hash,std::equal_to,std::allocator<std::pair<CNTK::Variable const,std::shared_ptr>>>&, CNTK::DeviceDescriptor const&, std::unordered_map<CNTK::Variable,std::shared_ptr,std::hash,std::equal_to,std::allocator<std::pair<CNTK::Variable const,std::shared_ptr>>>&) + 0x5d7 [0x7f3da88d6473] CNTK::Trainer:: TrainLocalMinibatch (std::unordered_map<CNTK::Variable,std::shared_ptr,std::hash,std::equal_to,std::allocator<std::pair<CNTK::Variable const,std::shared_ptr>>> const&, std::unordered_map<CNTK::Variable,std::shared_ptr,std::hash,std::equal_to,std::allocator<std::pair<CNTK::Variable const,std::shared_ptr>>>&, bool, CNTK::DeviceDescriptor const&) + 0xd3 [0x7f3da88d6efb] CNTK::Trainer:: TrainMinibatch (std::unordered_map<CNTK::Variable,CNTK::MinibatchData,std::hash,std::equal_to,std::allocator<std::pair<CNTK::Variable const,CNTK::MinibatchData>>> const&, std::unordered_map<CNTK::Variable,std::shared_ptr,std::hash,std::equal_to,std::allocator<std::pair<CNTK::Variable const,std::shared_ptr>>>&, CNTK::DeviceDescriptor const&) + 0x8b [0x7f3da88d70a0] CNTK::Trainer:: TrainMinibatch (std::unordered_map<CNTK::Variable,CNTK::MinibatchData,std::hash,std::equal_to,std::allocator<std::pair<CNTK::Variable const,CNTK::MinibatchData>>> const&, CNTK::DeviceDescriptor const&) + 0xa0 [0x7f3da93c11de] + 0x2651de [0x7f3db27465e9] PyCFunction_Call + 0xf9 [0x7f3db27cb7c0] PyEval_EvalFrameEx + 0x6ba0 [0x7f3db27ceb49] + 0x144b49 [0x7f3db27cddf5] 
PyEval_EvalFrameEx + 0x91d5 [0x7f3db27ceb49] + 0x144b49 [0x7f3db27cddf5] PyEval_EvalFrameEx + 0x91d5 [0x7f3db27ce166] PyEval_EvalFrameEx + 0x9546 [0x7f3db27ceb49] + 0x144b49 [0x7f3db27cecd8] PyEval_EvalCodeEx + 0x48 [0x7f3db27ced1b] PyEval_EvalCode + 0x3b [0x7f3db27f4020] PyRun_FileExFlags + 0x130 [0x7f3db27f5623] PyRun_SimpleFileExFlags + 0x173 [0x7f3db28108c7] Py_Main + 0xca7 [0x400add] main + 0x15d [0x7f3db17aff45] __libc_start_main + 0xf5 [0x4008b9]

best regards

mohamad-hasan-sohan-ajini commented 6 years ago

Also, there is weird behavior of the loss just before the out-of-memory error!

screenshot-2017-11-26 tensorboard

The orange curve is network loss with cutoff=100 and the blue is with cutoff=50.