rikeda71 / TorchCRF

An Implementation of CRF (Conditional Random Fields) in PyTorch 1.0
MIT License
135 stars 11 forks

Performance issue #3

Open andreabac3 opened 4 years ago

andreabac3 commented 4 years ago

I greatly appreciate your work, both for its ease of use and for your commitment. I may be wrong, but the library seems very slow compared to other packages that do the same job.

I checked, and all tensor operations are performed on the GPU (GTX 1070). During training, tqdm reports roughly one iteration every two seconds, and an epoch takes 2 hours. With other libraries and the same model, an epoch takes 15 minutes.

I can assure you that both the mask and the CRF layer are on the GPU.

I also tried forcing the placement with .to(device), but nothing changed:

self.crflayer = CRF(hparams.num_classes, pad_idx=0).to(device)
self.model.crflayer.forward(outputs, goldLabels, mask).to(device)
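A quick way to double-check the placement claim is to list the devices of everything involved. The sketch below is only an illustration: `devices_in_use` is a hypothetical helper, and `nn.Linear` stands in for the real CRF layer. Note that `Module.to(device)` moves parameters and buffers only, so a tensor created without `device=...` can still end up on the CPU.

```python
import torch
from torch import nn


def devices_in_use(module: nn.Module, *tensors: torch.Tensor) -> set:
    """Collect the device types of a module's parameters/buffers and of extra tensors."""
    found = {p.device.type for p in module.parameters()}
    found |= {b.device.type for b in module.buffers()}
    found |= {t.device.type for t in tensors}
    return found


# Stand-in layer and inputs (assumption): replace with the real CRF layer,
# emissions, labels and mask.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer = nn.Linear(4, 2).to(device)
inputs = torch.randn(3, 4, device=device)
stray = torch.ones(3, 4)  # created without device=..., so it lives on the CPU

# A single entry means everything is co-located.
print(devices_in_use(layer, inputs))
# On a GPU machine this set contains both "cuda" and "cpu".
print(devices_in_use(layer, inputs, stray))
```

If the set contains more than one device type, some tensor is being copied between host and GPU on every step, which is worth ruling out before looking at the CRF code itself.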

rikeda71 commented 4 years ago

Hi @andreabac3. I ran a speed test using TorchCRF (this repo) and pytorch-crf. The test code is shown below.

import torch
import cProfile

batch_size = 2
sequence_size = 3
num_labels = 5
labels = torch.LongTensor([[0, 2, 3], [1, 4, 1]]).cuda()  # (batch_size, sequence_size)
hidden = torch.randn((batch_size, sequence_size, num_labels), requires_grad=True).cuda()

from torchcrf import CRF
mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8).cuda()  # (batch_size, sequence_size)
def ossCRF(hidden, mask, labels):
    model = CRF(num_labels).cuda()
    for _ in range(1000):
        a = model(hidden, labels, mask=mask)
        a.mean().backward()

cProfile.run('ossCRF(hidden, mask, labels)')

from TorchCRF import CRF
mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]]).cuda()  # (batch_size, sequence_size)
def myCRF(hidden, mask, labels):
    crf = CRF(num_labels)
    for _ in range(1000):
        a = crf(hidden, labels, mask)
        a.mean().backward()

cProfile.run('myCRF(hidden, mask, labels)')

The following is the result of the test.

TorchCRF (this repo)

        87164 function calls in 2.879 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    2.879    2.879 <string>:1(<module>)
     1000    0.197    0.000    0.574    0.001 __init__.py:144(_compute_denominator_log_likelihood)
     1000    0.259    0.000    0.703    0.001 __init__.py:193(_compute_numerator_log_likelihood)
     1000    0.004    0.000    0.020    0.000 __init__.py:21(_make_grads)
     2000    0.095    0.000    0.095    0.000 __init__.py:221(<listcomp>)
     2000    0.109    0.000    0.118    0.000 __init__.py:225(<listcomp>)
        1    0.000    0.000    0.000    0.000 __init__.py:249(_initialize_parameters)
     3000    0.071    0.000    0.246    0.000 __init__.py:266(logsumexp)
        3    0.000    0.000    0.000    0.000 __init__.py:280(myTensor)
     1000    0.015    0.000    1.293    0.001 __init__.py:41(forward)
     1000    0.005    0.000    1.523    0.002 __init__.py:45(backward)
        1    0.000    0.000    0.000    0.000 __init__.py:9(__init__)
        1    0.028    0.028    2.879    2.879 a.py:12(myCRF)
        3    0.000    0.000    0.000    0.000 grad_mode.py:151(__init__)
        3    0.000    0.000    0.000    0.000 grad_mode.py:65(__enter__)
        3    0.000    0.000    0.000    0.000 grad_mode.py:69(__exit__)
        3    0.000    0.000    0.000    0.000 init.py:12(_no_grad_uniform_)
        3    0.000    0.000    0.000    0.000 init.py:74(uniform_)
        3    0.000    0.000    0.000    0.000 module.py:138(register_parameter)
     1000    0.009    0.000    1.305    0.001 module.py:540(__call__)
     2003    0.003    0.000    0.003    0.000 module.py:580(__getattr__)
       17    0.000    0.000    0.000    0.000 module.py:596(__setattr__)
        3    0.000    0.000    0.000    0.000 module.py:597(remove_from)
        1    0.000    0.000    0.000    0.000 module.py:71(__init__)
        3    0.000    0.000    0.000    0.000 parameter.py:23(__new__)
     1000    0.011    0.000    1.534    0.002 tensor.py:170(backward)
     2000    0.006    0.000    0.012    0.000 tensor.py:454(__iter__)
     4000    0.009    0.000    0.009    0.000 tensor.py:468(<lambda>)
        3    0.000    0.000    0.000    0.000 {built-in method _make_subclass}
        1    0.000    0.000    2.879    2.879 {built-in method builtins.exec}
        3    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
     2037    0.001    0.000    0.001    0.000 {built-in method builtins.isinstance}
     2000    0.000    0.000    0.000    0.000 {built-in method builtins.iter}
     2000    0.001    0.000    0.001    0.000 {built-in method builtins.len}
     4000    0.133    0.000    0.133    0.000 {built-in method cat}
     3000    0.039    0.000    0.039    0.000 {built-in method exp}
     3000    0.036    0.000    0.036    0.000 {built-in method log}
     1000    0.015    0.000    0.015    0.000 {built-in method ones_like}
     3000    0.042    0.000    0.042    0.000 {built-in method sum}
     3000    0.005    0.000    0.005    0.000 {built-in method torch._C._get_tracing_state}
        1    0.000    0.000    0.000    0.000 {built-in method torch._C._log_api_usage_once}
        6    0.000    0.000    0.000    0.000 {built-in method torch._C.is_grad_enabled}
        6    0.000    0.000    0.000    0.000 {built-in method torch._C.set_grad_enabled}
     1000    0.000    0.000    0.000    0.000 {method 'append' of 'list' objects}
        3    0.000    0.000    0.000    0.000 {method 'cuda' of 'torch._C._TensorBase' objects}
     2000    0.001    0.000    0.001    0.000 {method 'dim' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        3    0.000    0.000    0.000    0.000 {method 'format' of 'str' objects}
     2000    0.026    0.000    0.026    0.000 {method 'gather' of 'torch._C._TensorBase' objects}
       45    0.000    0.000    0.000    0.000 {method 'get' of 'dict' objects}
     1000    0.012    0.000    0.012    0.000 {method 'long' of 'torch._C._TensorBase' objects}
     3000    0.047    0.000    0.047    0.000 {method 'max' of 'torch._C._TensorBase' objects}
     1000    0.012    0.000    0.012    0.000 {method 'mean' of 'torch._C._TensorBase' objects}
     1000    0.000    0.000    0.000    0.000 {method 'numel' of 'torch._C._TensorBase' objects}
     1000    1.498    0.001    1.498    0.001 {method 'run_backward' of 'torch._C._EngineBase' objects}
     4000    0.004    0.000    0.004    0.000 {method 'size' of 'torch._C._TensorBase' objects}
     1000    0.005    0.000    0.005    0.000 {method 'squeeze' of 'torch._C._TensorBase' objects}
     1000    0.013    0.000    0.013    0.000 {method 'sum' of 'torch._C._TensorBase' objects}
     2000    0.029    0.000    0.029    0.000 {method 'to' of 'torch._C._TensorBase' objects}
     2000    0.075    0.000    0.075    0.000 {method 'type' of 'torch._C._TensorBase' objects}
        3    0.000    0.000    0.000    0.000 {method 'uniform_' of 'torch._C._TensorBase' objects}
     7000    0.032    0.000    0.032    0.000 {method 'unsqueeze' of 'torch._C._TensorBase' objects}
     2000    0.001    0.000    0.001    0.000 {method 'values' of 'collections.OrderedDict' objects}
    11000    0.032    0.000    0.032    0.000 {method 'view' of 'torch._C._TensorBase' objects}

pytorch-crf

        51195 function calls in 1.667 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    1.667    1.667 <string>:1(<module>)
        3    0.000    0.000    0.000    0.000 __future__.py:18(get_overwrite_module_params_on_conversion)
     1000    0.025    0.000    0.039    0.000 __init__.py:141(_validate)
     1000    0.261    0.000    0.336    0.000 __init__.py:169(_compute_score)
     1000    0.087    0.000    0.302    0.000 __init__.py:208(_compute_normalizer)
     1000    0.004    0.000    0.019    0.000 __init__.py:21(_make_grads)
        1    0.000    0.000    0.000    0.000 __init__.py:38(__init__)
     1000    0.004    0.000    0.926    0.001 __init__.py:45(backward)
        1    0.000    0.000    0.000    0.000 __init__.py:50(reset_parameters)
     1000    0.017    0.000    0.704    0.001 __init__.py:63(forward)
        1    0.013    0.013    1.667    1.667 a.py:23(ossCRF)
        6    0.000    0.000    0.000    0.000 grad_mode.py:151(__init__)
        6    0.000    0.000    0.000    0.000 grad_mode.py:65(__enter__)
        6    0.000    0.000    0.000    0.000 grad_mode.py:69(__exit__)
        3    0.000    0.000    0.000    0.000 init.py:12(_no_grad_uniform_)
        3    0.000    0.000    0.000    0.000 init.py:74(uniform_)
        3    0.000    0.000    0.000    0.000 module.py:138(register_parameter)
        1    0.000    0.000    0.000    0.000 module.py:201(_apply)
        3    0.000    0.000    0.000    0.000 module.py:205(compute_should_use_set_data)
        1    0.000    0.000    0.000    0.000 module.py:293(cuda)
        3    0.000    0.000    0.000    0.000 module.py:307(<lambda>)
     1000    0.006    0.000    0.712    0.001 module.py:540(__call__)
     6006    0.004    0.000    0.004    0.000 module.py:580(__getattr__)
       14    0.000    0.000    0.000    0.000 module.py:596(__setattr__)
        3    0.000    0.000    0.000    0.000 module.py:597(remove_from)
        1    0.000    0.000    0.000    0.000 module.py:71(__init__)
        1    0.000    0.000    0.000    0.000 module.py:961(children)
        1    0.000    0.000    0.000    0.000 module.py:970(named_children)
        3    0.000    0.000    0.000    0.000 parameter.py:23(__new__)
     1000    0.003    0.000    0.929    0.001 tensor.py:170(backward)
        3    0.000    0.000    0.000    0.000 tensor.py:737(grad)
        3    0.000    0.000    0.000    0.000 {built-in method _has_compatible_shallow_copy_type}
        3    0.000    0.000    0.000    0.000 {built-in method _make_subclass}
     3000    0.028    0.000    0.028    0.000 {built-in method arange}
        1    0.000    0.000    1.667    1.667 {built-in method builtins.exec}
        6    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
     2031    0.001    0.000    0.001    0.000 {built-in method builtins.isinstance}
     2000    0.000    0.000    0.000    0.000 {built-in method builtins.len}
        3    0.000    0.000    0.000    0.000 {built-in method empty}
     2000    0.170    0.000    0.170    0.000 {built-in method logsumexp}
     1000    0.014    0.000    0.014    0.000 {built-in method ones_like}
     1000    0.002    0.000    0.002    0.000 {built-in method torch._C._get_tracing_state}
        1    0.000    0.000    0.000    0.000 {built-in method torch._C._log_api_usage_once}
       12    0.000    0.000    0.000    0.000 {built-in method torch._C.is_grad_enabled}
       12    0.000    0.000    0.000    0.000 {built-in method torch._C.set_grad_enabled}
     1000    0.021    0.000    0.021    0.000 {built-in method where}
     3000    0.032    0.000    0.032    0.000 {method 'all' of 'torch._C._TensorBase' objects}
     1000    0.000    0.000    0.000    0.000 {method 'append' of 'list' objects}
        3    0.000    0.000    0.000    0.000 {method 'cuda' of 'torch._C._TensorBase' objects}
     5000    0.001    0.000    0.001    0.000 {method 'dim' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
     1000    0.013    0.000    0.013    0.000 {method 'float' of 'torch._C._TensorBase' objects}
        3    0.000    0.000    0.000    0.000 {method 'format' of 'str' objects}
       36    0.000    0.000    0.000    0.000 {method 'get' of 'dict' objects}
        3    0.000    0.000    0.000    0.000 {method 'items' of 'collections.OrderedDict' objects}
     1000    0.010    0.000    0.010    0.000 {method 'long' of 'torch._C._TensorBase' objects}
     1000    0.012    0.000    0.012    0.000 {method 'mean' of 'torch._C._TensorBase' objects}
     1000    0.000    0.000    0.000    0.000 {method 'numel' of 'torch._C._TensorBase' objects}
     1000    0.903    0.001    0.903    0.001 {method 'run_backward' of 'torch._C._EngineBase' objects}
     4000    0.003    0.000    0.003    0.000 {method 'size' of 'torch._C._TensorBase' objects}
     2000    0.022    0.000    0.022    0.000 {method 'sum' of 'torch._C._TensorBase' objects}
        3    0.000    0.000    0.000    0.000 {method 'uniform_' of 'torch._C._TensorBase' objects}
     3000    0.010    0.000    0.010    0.000 {method 'unsqueeze' of 'torch._C._TensorBase' objects}
     2000    0.000    0.000    0.000    0.000 {method 'values' of 'collections.OrderedDict' objects}
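Both listings are ordered by "standard name", which buries the expensive rows. A minimal sketch of sorting the same kind of profile by internal time with the standard-library `pstats` module follows; `profile_top` and `busy_work` are hypothetical names, and `busy_work` merely stands in for `ossCRF` or `myCRF`.

```python
import cProfile
import io
import pstats


def profile_top(func, *args, sort_key="tottime", limit=10):
    """Run func(*args) under cProfile and return the top `limit` rows sorted by `sort_key`."""
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        func(*args)
    finally:
        profiler.disable()
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    stats.strip_dirs().sort_stats(sort_key).print_stats(limit)
    return buffer.getvalue()


def busy_work(n):
    # Hypothetical stand-in for ossCRF / myCRF from the test code.
    return sum(i * i for i in range(n))


print(profile_top(busy_work, 100_000))
```

Passing `sort_key="cumulative"` instead ranks functions by the time spent in them including their callees.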

The two results show that TorchCRF makes more function calls than pytorch-crf (87,164 vs. 51,195) and that run_backward takes longer (1.498 s vs. 0.903 s over 1,000 iterations).

I don't know why run_backward takes so long or how to speed it up.
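One caveat when reading these numbers: CUDA kernels are launched asynchronously, so cProfile mostly measures CPU-side time such as Python overhead and kernel launches. With tensors as small as in this test, that overhead is likely to dominate. A hedged sketch of a wall-clock measurement that synchronizes the GPU before reading the clock is shown below; `time_forward_backward` and `toy_loss` are hypothetical names, and `toy_loss` is only a stand-in for the CRF call.

```python
import time

import torch


def time_forward_backward(loss_fn, n_iters=100, warmup=10):
    """Return mean wall-clock seconds per forward+backward call of loss_fn."""
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):       # warm-up: exclude one-time setup costs
        loss_fn().backward()
    if use_cuda:
        torch.cuda.synchronize()  # wait for queued kernels before starting the clock
    start = time.perf_counter()
    for _ in range(n_iters):
        loss_fn().backward()
    if use_cuda:
        torch.cuda.synchronize()  # wait for all kernels before stopping the clock
    return (time.perf_counter() - start) / n_iters


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Same tiny shape as the test code: (batch_size=2, sequence_size=3, num_labels=5).
hidden = torch.randn(2, 3, 5, device=device, requires_grad=True)


def toy_loss():
    # Hypothetical stand-in for `crf(hidden, labels, mask).mean()`.
    return torch.logsumexp(hidden, dim=-1).mean()


print(f"{time_forward_backward(toy_loss) * 1e3:.3f} ms per iteration")
```

Repeating the measurement with a realistic batch size and sequence length would show whether the gap between the two libraries persists once the per-operation overhead is amortized.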

TorchCRF calls many tensor reshaping functions (view 11,000 times, unsqueeze 7,000 times, and squeeze 1,000 times in this run), which may be part of the problem.
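Another visible difference in the two profiles is that the TorchCRF run shows a Python-level logsumexp alongside separate max, exp, sum, and log calls (3,000 each), while the pytorch-crf run shows the built-in logsumexp. Every extra small operation also adds a node to the autograd graph that run_backward has to traverse. The sketch below compares a hand-written log-sum-exp with `torch.logsumexp`; `manual_logsumexp` is an assumption about what such a function typically looks like, not TorchCRF's actual code.

```python
import torch


def manual_logsumexp(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Hand-rolled, numerically stable log-sum-exp built from several small ops."""
    max_val, _ = x.max(dim=dim, keepdim=True)     # shift by the max for stability
    summed = torch.exp(x - max_val).sum(dim=dim)  # separate exp and sum ops
    return max_val.squeeze(dim) + torch.log(summed)


x = torch.randn(2, 5, 5)
fused = torch.logsumexp(x, dim=1)    # single built-in op
manual = manual_logsumexp(x, dim=1)  # max, sub, exp, sum, squeeze, log, add

# Both versions should agree up to floating-point error.
print(torch.allclose(fused, manual, atol=1e-5))
```

Whether this accounts for the measured gap is not established by the profile alone; it is one candidate to test.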

rikeda71 commented 4 years ago

Hi @andreabac3, I changed the algorithms to improve performance in #10. If you'd like, please try it out and evaluate the performance. Thank you.