harvardnlp / pytorch-struct

Fast, general, and tested differentiable structured prediction in PyTorch
http://harvardnlp.github.io/pytorch-struct
MIT License

FastLogSemiring #108

Open w-cheng opened 3 years ago

w-cheng commented 3 years ago

Hi,

Thanks for making this library; it's amazing to have these different CRFs wrapped up in a common, easy-to-use framework.

I've been playing with the LinearChainCRF, and one thing I noticed is that memory usage can be very high during the loss backward pass, on both CPU and GPU. I found that FastLogSemiring in fast_semirings.py uses genbmm.logbmm(), and it significantly reduces memory usage on GPU if I change the default LogSemiring used in the StructDistribution class to FastLogSemiring. However, I haven't seen this documented anywhere, so my questions are (a small sketch of the workload I'm profiling follows the questions):

  1. Is FastLogSemiring ready to be used? It's not included in test_semirings.py.
  2. If so, what would be the best way to switch between LogSemiring and FastLogSemiring? Is there a plan to add a parameter for choosing the semiring in the StructDistribution class?
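
For context, here's roughly the kind of workload where I see the memory spike. The sizes are made up, and I'm just driving the backward pass through the public partition property:

import torch
from torch_struct import LinearChainCRF

batch, N, C = 32, 500, 64  # made-up sizes for illustration
log_potentials = torch.randn(batch, N - 1, C, C, requires_grad=True)

dist = LinearChainCRF(log_potentials)
# Backpropagating through the log-partition is where memory peaks
# with the default LogSemiring.
dist.partition.sum().backward()
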
srush commented 3 years ago

Yes! It works and is heavily tested. We should make it the default. It just requires that the GPU kernels in genbmm be installed.
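
For the kernels, the install is roughly this (check the genbmm README for the current command, since it may have changed):

pip install git+https://github.com/harvardnlp/genbmm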

w-cheng commented 3 years ago

What do you think of performing a check for the genbmm library in the imports, like:

# Only enable the fast path if the genbmm kernels are installed.
has_genbmm = False
try:
    import genbmm

    has_genbmm = True
    from .semirings import FastLogSemiring
except ImportError:
    pass

then a method on the StructDistribution class:

    def default_log_semiring(self):
        # Use the genbmm-backed semiring only on GPU when genbmm is installed.
        return FastLogSemiring if has_genbmm and self.log_potentials.is_cuda else LogSemiring

So instead of returning LogSemiring by default in the marginals and partition properties, we'd call this default_log_semiring().
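
Concretely, marginals could become something like this (just a sketch; I'm assuming the existing _struct helper and lengths attribute stay as they are):

    @property
    def marginals(self):
        # Dispatch through the helper instead of hard-coding LogSemiring.
        return self._struct(self.default_log_semiring()).marginals(
            self.log_potentials, self.lengths
        )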

srush commented 3 years ago

Yes, that would be great. You can do it for max too.
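
i.e. something along these lines (a sketch, assuming fast_semirings.py also exposes FastMaxSemiring and that it's imported under the same has_genbmm guard):

    def default_max_semiring(self):
        # Same dispatch pattern, but for max/Viterbi.
        return FastMaxSemiring if has_genbmm and self.log_potentials.is_cuda else MaxSemiring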