kordk / torch-ecpg

(GPU accelerated) eCpG mapper
BSD 3-Clause "New" or "Revised" License

Optimize probability function #24

Closed: liamgd closed this issue 1 year ago

liamgd commented 1 year ago

Currently, the regression algorithms evaluate the Student's t density by applying `torch.exp` to `torch.distributions.studentT.StudentT(df).log_prob`. This could be made faster.
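As a point of reference, the current approach amounts to the following. `baseline_prob` is a hypothetical wrapper name used only for illustration, not a function from torch-ecpg.

```python
import torch


def baseline_prob(value: torch.Tensor, df: int) -> torch.Tensor:
    # Build a Student's t distribution (loc=0, scale=1), evaluate the
    # log-density, then exponentiate to recover the density itself.
    return torch.distributions.studentT.StudentT(df).log_prob(value).exp()


# With df = 1 the Student's t distribution is the standard Cauchy,
# whose density at 0 is 1/pi (about 0.3183).
print(baseline_prob(torch.tensor([0.0]), 1))
```

Note that a new distribution object is constructed on every call, which is part of what a hand-written replacement can avoid.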

liamgd commented 1 year ago

First try:

```python
import math

import torch


def prob_create(df: int, device: torch.device, dtype: torch.dtype):
    # First attempt: evaluate the Student's t density directly (not in log space)
    @torch.jit.script
    def prob(value: torch.Tensor, dft, scalar, exponent):
        return (value ** 2.0 / dft + 1) ** exponent / scalar ** exponent

    dft = torch.tensor(df, device=device, dtype=dtype)
    scalar = torch.tensor(
        math.sqrt(df)
        * math.sqrt(math.pi)
        * math.exp(math.lgamma(0.5 * df) - math.lgamma(0.5 * df + 0.5)),
        device=device,
        dtype=dtype,
    )
    exponent = torch.tensor(-0.5 * (df + 1.0), device=device, dtype=dtype)

    return lambda value: prob(value, dft, scalar, exponent)
```

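Computation overflows and results in a tensor of infinities. The normalizing constant is raised to `exponent` (`/ scalar ** exponent`) instead of being divided out once (`/ scalar`), and for large `df` the value of `scalar ** exponent` falls far below the smallest normal single-precision float, so the division overflows to `inf`.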
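The arithmetic behind the overflow can be checked with plain Python floats. This sketch assumes `df = 200` (the value used in the later benchmark; the `df` used for the first try is not stated), hard-codes the float32 limits so only the standard library is needed, and uses a hypothetical helper `first_try_terms` that recomputes the first try's `scalar` and `exponent`.

```python
import math

# float32 limits, hard-coded so this sketch needs only the standard library
FLOAT32_MAX = 3.4028234663852886e38
FLOAT32_MIN_NORMAL = 1.1754943508222875e-38


def first_try_terms(df: int):
    # Same constants as the first try; scalar is 1 / (t density normalizing constant)
    scalar = (
        math.sqrt(df)
        * math.sqrt(math.pi)
        * math.exp(math.lgamma(0.5 * df) - math.lgamma(0.5 * df + 0.5))
    )
    exponent = -0.5 * (df + 1.0)
    return scalar, exponent


scalar, exponent = first_try_terms(200)
divisor = scalar ** exponent          # what the first try divides by
print(divisor < FLOAT32_MIN_NORMAL)   # below the float32 normal range
print(1.0 / divisor > FLOAT32_MAX)    # so the quotient is inf in float32
print(1.0 / scalar)                   # density at 0 if scalar is divided out once
```

Since `(value ** 2.0 / dft + 1) ** exponent` stays close to 1 for small inputs, dividing it by `divisor` exceeds the float32 range, while dividing by `scalar` once gives a density near the standard normal's value at 0.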

liamgd commented 1 year ago

One equivalent algorithm, computed in log space:

```python
import math

import torch


@torch.jit.script
class Prob:
    def __init__(
        self, df: int, device: torch.device, dtype: torch.dtype
    ) -> None:
        self.df = df
        # Log of the Student's t normalizing constant
        self.offset = torch.tensor(
            -0.5 * math.log(df)
            - 0.5 * math.log(math.pi)
            - math.lgamma(0.5 * df)
            + math.lgamma(0.5 * (df + 1.0)),
            device=device,
            dtype=dtype,
        )
        # (df + 1) / 2, the power applied to (1 + value ** 2 / df)
        self.scalar = torch.tensor(
            0.5 * (self.df + 1.0), device=device, dtype=dtype
        )

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        # Exponentiate the log-density; log1p stays accurate for small value ** 2 / df
        return (
            self.offset - torch.log1p(value ** 2.0 / self.df) * self.scalar
        ).exp()
```

For just-in-time compilation with the `@torch.jit.script` decorator, a class must be used to accept a variable `df`.
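For reference, `offset` and `scalar` in `Prob` come from taking the logarithm of the Student's t density with $\nu$ = `df` degrees of freedom. In log space the large power $\frac{\nu + 1}{2}$ only multiplies a logarithm, so nothing is raised to a large power directly:

$$
f(x; \nu) = \frac{\Gamma\left(\frac{\nu + 1}{2}\right)}{\sqrt{\nu \pi}\, \Gamma\left(\frac{\nu}{2}\right)} \left(1 + \frac{x^2}{\nu}\right)^{-\frac{\nu + 1}{2}}
$$

$$
\log f(x; \nu) = \underbrace{-\tfrac{1}{2} \log \nu - \tfrac{1}{2} \log \pi - \log \Gamma\left(\tfrac{\nu}{2}\right) + \log \Gamma\left(\tfrac{\nu + 1}{2}\right)}_{\text{offset}} - \underbrace{\tfrac{\nu + 1}{2}}_{\text{scalar}} \log\left(1 + \frac{x^2}{\nu}\right)
$$

The last logarithm is what `torch.log1p(value ** 2.0 / self.df)` computes.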

liamgd commented 1 year ago

Kernprof tests without JIT compilation, because torch JIT does not support running with line_profiler.

Comparing `torch.distributions.studentT.StudentT(df).log_prob(value).exp()` with the new `Prob` algorithm over 100 repeats for a more reliable timing. Test data: a 100,000,000 × 4 tensor from `torch.rand`.

```
Total time: 13.1775 s
File: .\tecpg\regression_full.py
Function: test_prob at line 259

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   259                                           @profile
   260                                           def test_prob() -> None:
   261         1          6.9      6.9      0.0      torch.cuda.empty_cache()
   262         1          0.2      0.2      0.0      df = 200
   263         1         12.2     12.2      0.0      device = torch.device('cuda')
   264         1          0.7      0.7      0.0      dtype = torch.float32
   265         1          0.5      0.5      0.0      prob_one = lambda t: torch.distributions.StudentT(df).log_prob(t).exp()
   266         1      63757.1  63757.1      0.5      prob_two = Prob(df, device, dtype).prob
   267         1     848254.1 848254.1      6.4      test = torch.rand((100_000_000, 4), device=device, dtype=dtype)
   268         1        311.1    311.1      0.0      print(torch.cuda.memory_allocated())
   269       100         63.3      0.6      0.0      for _ in range(100):
   270       100   12254983.8 122549.8     93.0          prob_one(test)
   271       100      10153.0    101.5      0.1          prob_two(test)
```

Average time per loop: 122549.8 μs with `torch.distributions.studentT.StudentT(df).log_prob(value).exp()`, versus 101.5 μs with the new `Prob` algorithm (not JIT compiled).

In this test, the new algorithm is approximately 1207 times faster (122549.8 / 101.5 ≈ 1207.39).
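One caveat when reading per-line timings like these: CUDA kernels are launched asynchronously, so a line can return before its GPU work finishes, and that work may be charged to whichever later line forces a synchronization (argument validation inside `torch.distributions` can do this). A cross-check with `torch.utils.benchmark.Timer`, which warms up and synchronizes CUDA while timing, might look like the sketch below. `time_per_call` is a hypothetical helper, and its default sizes are illustrative assumptions, much smaller than the 100,000,000 × 4 test above.

```python
import math

import torch
import torch.utils.benchmark as benchmark


class Prob:
    # The Prob class from this thread, without @torch.jit.script (as in the kernprof test)
    def __init__(self, df: int, device: torch.device, dtype: torch.dtype) -> None:
        self.df = df
        self.offset = torch.tensor(
            -0.5 * math.log(df)
            - 0.5 * math.log(math.pi)
            - math.lgamma(0.5 * df)
            + math.lgamma(0.5 * (df + 1.0)),
            device=device,
            dtype=dtype,
        )
        self.scalar = torch.tensor(0.5 * (self.df + 1.0), device=device, dtype=dtype)

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        return (self.offset - torch.log1p(value ** 2.0 / self.df) * self.scalar).exp()


def time_per_call(df: int = 200, rows: int = 1_000_000, number: int = 10) -> dict:
    # Hypothetical helper; rows and number are illustrative, not the thread's settings
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32
    test = torch.rand((rows, 4), device=device, dtype=dtype)
    candidates = {
        "studentt": lambda t: torch.distributions.StudentT(df).log_prob(t).exp(),
        "prob": Prob(df, device, dtype).prob,
    }
    results = {}
    for name, fn in candidates.items():
        # Timer synchronizes CUDA when taking timestamps, so queued kernels are counted
        timer = benchmark.Timer(stmt="fn(test)", globals={"fn": fn, "test": test})
        results[name] = timer.timeit(number).mean  # mean seconds per call
    return results
```

Calling `time_per_call()` returns the mean seconds per call for each candidate; the ratio of the two values is the synchronized counterpart of the speedup above.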

liamgd commented 1 year ago

Tests of combinations of region filtration, p-value filtration, and chunking, each with and without `@torch.jit.script` on the probability function. Input: 300 samples, 50,000 methylation loci, and 10,000 gene expression loci, on an RTX 2070 Super with 8 GB of VRAM.

The first number after each command is the time without torch JIT, and the second is with torch JIT.

| Command | Without JIT (s) | With JIT (s) |
| --- | --- | --- |
| `tecpg run mlr-full` | *Too much GPU memory allocation* | *Too much GPU memory allocation* |
| `tecpg run mlr-full --cis` | 93.1929 | 98.289 |
| `tecpg run mlr-full --distal` | 94.3363 | 98.0763 |
| `tecpg run mlr-full --trans` | *Too much GPU memory allocation* | *Too much GPU memory allocation* |
| `tecpg run mlr-full -p 0.05` | 93.2036 | 95.8579 |
| `tecpg run mlr-full -p 0.05 --cis` | 93.427 | 97.4576 |
| `tecpg run mlr-full -p 0.05 --distal` | 93.5236 | 97.2995 |
| `tecpg run mlr-full -p 0.05 --trans` | 93.5954 | 96.8116 |
| `tecpg run mlr-full -l 2000` | 169.2086 | 169.5929 |
| `tecpg run mlr-full -l 2000 --cis` | 102.665 | 101.5853 |
| `tecpg run mlr-full -l 2000 --distal` | 104.6743 | 107.7221 |
| `tecpg run mlr-full -l 2000 --trans` | 180.5167 | 175.9578 |
| `tecpg run mlr-full -l 2000 -p 0.05` | 101.5071 | 102.3017 |
| `tecpg run mlr-full -l 2000 -p 0.05 --cis` | 97.5001 | 97.8865 |
| `tecpg run mlr-full -l 2000 -p 0.05 --distal` | 98.7852 | 98.6476 |
| `tecpg run mlr-full -l 2000 -p 0.05 --trans` | 101.6967 | 100.2482 |

For this input size, torch JIT does not improve performance overall: it is slower in 10 of the 14 configurations that completed, and the four faster runs (all with `-l 2000`) gain at most about 4.6 seconds.

liamgd commented 1 year ago

Tests with 300 samples, 5,000 methylation loci, and 5,000 gene expression loci.

Kernprof of class-based constant folding: 0.0001124 seconds initialization, 0.922004 seconds computation.

```
Total time: 0.0001124 s
File: .\tecpg\regression_full.py
Function: __init__ at line 17

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    17                                               @profile
    18                                               def __init__(
    19                                                   self, df: int, device: torch.device, dtype: torch.dtype
    20                                               ) -> None:
    21         1          1.1      1.1      1.0          self.df = df
    22         1         69.2     69.2     61.6          self.offset = torch.tensor(
    23         1          4.5      4.5      4.0              -0.5 * math.log(df)
    24         1          1.3      1.3      1.2              - 0.5 * math.log(math.pi)
    25         1          4.2      4.2      3.7              - math.lgamma(0.5 * df)
    26         1          0.9      0.9      0.8              + math.lgamma(0.5 * (df + 1.0)),
    27         1          0.2      0.2      0.2              device=device,
    28         1          0.2      0.2      0.2              dtype=dtype,
    29                                                   )
    30         1         30.1     30.1     26.8          self.scalar = torch.tensor(
    31         1          0.7      0.7      0.6              0.5 * (self.df + 1.0), device=device, dtype=dtype
    32                                                   )

Total time: 0.922004 s
File: .\tecpg\regression_full.py
Function: prob at line 34

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    34                                               @profile
    35                                               def prob(self, value: torch.Tensor):
    36     10000       2868.1      0.3      0.3          return (
    37     10000     792448.5     79.2     85.9              self.offset - torch.log1p(value ** 2.0 / self.df) * self.scalar
    38     10000     126687.0     12.7     13.7          ).exp()
```
Kernprof of closure-based constant folding: 0.0001542 seconds initialization, 0.888699 seconds computation.

```
Total time: 0.0001542 s
File: .\tecpg\regression_full.py
Function: create_prob at line 40

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    40                                           @profile
    41                                           def create_prob(df: int, device: torch.device, dtype: torch.dtype):
    42         1         65.8     65.8     42.7      offset = torch.tensor(
    43         1          3.4      3.4      2.2          -0.5 * math.log(df)
    44         1          1.5      1.5      1.0          - 0.5 * math.log(math.pi)
    45         1          3.8      3.8      2.5          - math.lgamma(0.5 * df)
    46         1          1.2      1.2      0.8          + math.lgamma(0.5 * (df + 1.0)),
    47         1          0.3      0.3      0.2          device=device,
    48         1          0.2      0.2      0.1          dtype=dtype,
    49                                               )
    50         1         29.5     29.5     19.1      scalar = torch.tensor(0.5 * (df + 1.0), device=device, dtype=dtype)
    51
    52         1          0.7      0.7      0.5      @profile
    53         1         47.5     47.5     30.8      def prob(value: torch.Tensor):
    54                                                   return (offset - torch.log1p(value ** 2.0 / df) * scalar).exp()
    55
    56         1          0.3      0.3      0.2      return prob

Total time: 0.888699 s
File: .\tecpg\regression_full.py
Function: prob at line 52

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    52                                               @profile
    53                                               def prob(value: torch.Tensor):
    54     10000     888699.1     88.9    100.0          return (offset - torch.log1p(value ** 2.0 / df) * scalar).exp()
```

The closure approach is incompatible with torch JIT, but the previous tests showed that JIT slightly decreases performance anyway. In the tests above, the closure version spends about 3.6% less time computing than the class version (0.888699 s vs. 0.922004 s over 10,000 calls). Implemented in 9dae7ff.