kordk / torch-ecpg

(GPU accelerated) eCpG mapper
BSD 3-Clause "New" or "Revised" License

Regression Optimization #17

Closed: liamgd closed this issue 1 year ago

liamgd commented 1 year ago

Optimize the multiple linear regression to reduce algorithm run time and to utilize the CUDA device better.

liamgd commented 1 year ago

Speed tests (excluding data saving) after commit

Data: 300 samples, 10,000 gt rows, 10,000 mt rows

| Mode | Constant calculation (s) | Loop over methylation loci (s) |
|---|---|---|
| CUDA | 1.3899 | 22.233 |
| CPU, 1 thread | 0.0876 | 216.3914 |
liamgd commented 1 year ago

Profiling regression_full with line_profiler gives the following timings for 100 samples, 1000 methylation loci, and 1000 gene expression loci, without chunking or p-value filtration. The size of the input gives a variety of different analysis results, and p-value filtration would reduce the cost of creating the output dataframe, which is a major consumer of time.

Output file from kernprof:

```
[1000000 rows x 4 columns]
Wrote profile results to regression_full.py.lprof
Timer unit: 1e-06 s

Total time: 57.6697 s
File: tecpg/regression_full.py
Function: regression_full at line 15

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    15                                           @profile
    16                                           def regression_full(
    17                                               M: pandas.DataFrame,
    18                                               G: pandas.DataFrame,
    19                                               C: pandas.DataFrame,
    20                                               loci_per_chunk: Optional[int] = None,
    21                                               p_thresh: Optional[float] = None,
    22                                               output_dir: Optional[str] = None,
    23                                               *,
    24                                               logger: Logger = Logger(),
    25                                           ) -> Optional[pandas.DataFrame]:
    26         1          0.9      0.9      0.0      if (output_dir is None) != (loci_per_chunk is None):
    27                                                   error = 'Output dir and chunk size must be defined together.'
    28                                                   logger.error(error)
    29                                                   raise ValueError(error)
    30
    31         1         29.0     29.0      0.0      logger.info('Initializing regression variables')
    32         1      10657.7  10657.7      0.0      device = get_device(**logger)
    33         1          1.2      1.2      0.0      dtype = torch.float32
    34         1         15.7     15.7      0.0      nrows, ncols = C.shape[0], C.shape[1] + 1
    35         1          4.9      4.9      0.0      gt_count, mt_count = len(G), len(M)
    36         1         13.5     13.5      0.0      mt_site_names = numpy.array(M.index.values)
    37         1          6.5      6.5      0.0      gt_site_names = numpy.array(G.index.values)
    38         1          0.4      0.4      0.0      df = nrows - ncols - 1
    39         1         12.7     12.7      0.0      logger.info('Running with {0} degrees of freedom', df)
    40         1     859012.6 859012.6      1.5      dft_sqrt = torch.tensor(df, device=device, dtype=dtype).sqrt()
    41         1        379.2    379.2      0.0      log_prob = torch.distributions.studentT.StudentT(df).log_prob
    42         1         40.6     40.6      0.0      M_np = M.to_numpy()
    43         1         13.3     13.3      0.0      columns = ['const_p', 'mt_p'] + [val + '_p' for val in C.columns]
    44         1          0.2      0.2      0.0      last_index = 0
    45         1          0.2      0.2      0.0      results = []
    46         1          0.2      0.2      0.0      if loci_per_chunk:
    47                                                   chunk_count = math.ceil(len(M) / loci_per_chunk)
    48                                                   logger.info('Initializing output directory')
    49                                                   initialize_dir(output_dir, **logger)
    50         1          0.3      0.3      0.0      if p_thresh is not None:
    51                                                   output_sizes = []
    52                                                   indices_list = []
    53
    54         1         28.3     28.3      0.0      logger.start_timer('info', 'Running regression_full...')
    55         1        129.0    129.0      0.0      Ct: torch.Tensor = torch.tensor(
    56         1        195.7    195.7      0.0          C.to_numpy(), device=device, dtype=dtype
    57         1         75.6     75.6      0.0      ).repeat(gt_count, 1, 1)
    58         1         44.8     44.8      0.0      logger.time('Converted C to tensor in {l} seconds')
    59         1       1073.3   1073.3      0.0      Gt: torch.Tensor = torch.tensor(
    60         1         18.1     18.1      0.0          G.to_numpy(), device=device, dtype=dtype
    61         1         28.8     28.8      0.0      ).unsqueeze(2)
    62         1         30.7     30.7      0.0      logger.time('Converted G to tensor in {l} seconds')
    63         1         65.7     65.7      0.0      ones = torch.ones((gt_count, nrows, 1), device=device, dtype=dtype)
    64         1         29.4     29.4      0.0      logger.time('Created ones in {l} seconds')
    65         1        214.4    214.4      0.0      X: torch.Tensor = torch.cat((ones, Gt, Ct), 2)
    66         1         23.1     23.1      0.0      logger.time('Created X in {l} seconds')
    67         1         19.3     19.3      0.0      Xt = X.mT
    68         1         16.9     16.9      0.0      logger.time('Transposed X in {l} seconds')
    69         1     591599.1 591599.1      1.0      XtXi = Xt.bmm(X).inverse()
    70         1         51.9     51.9      0.0      logger.time('Calculated XtXi in {l} seconds')
    71         1         71.5     71.5      0.0      XtXi_diag_sqrt = torch.diagonal(XtXi, dim1=1, dim2=2).sqrt()
    72         1         31.5     31.5      0.0      logger.time('Calculated XtXi_diag in {l} seconds')
    73         1         46.5     46.5      0.0      XtXi_Xt = XtXi.bmm(Xt)
    74         1         21.2     21.2      0.0      logger.time('Calculated XtXi_Xt in {l} seconds')
    75         1         15.7     15.7      0.0      logger.time('Calculated X constants in {t} seconds')
    76         1      48421.9  48421.9      0.1      with Pool() as pool:
    77      1000       1384.7      1.4      0.0          for index, M_row in enumerate(M_np, 1):
    78      1000      43123.4     43.1      0.1              Y = torch.tensor(M_row, device=device, dtype=dtype)
    79      1000      34386.2     34.4      0.1              B = XtXi_Xt.matmul(Y)
    80      1000      56001.9     56.0      0.1              E = (Y.unsqueeze(1) - X.bmm(B.unsqueeze(2))).squeeze(2)
    81      1000      81465.1     81.5      0.1              scalars = (torch.sum(E * E, 1)).view((-1, 1)).sqrt() / dft_sqrt
    82      1000      18743.4     18.7      0.0              S = XtXi_diag_sqrt * scalars
    83      1000      16048.1     16.0      0.0              T = B / S
    84      1000     294608.8    294.6      0.5              P = torch.exp(log_prob(T))
    85      1000        514.9      0.5      0.0              if p_thresh is None:
    86      1000        792.4      0.8      0.0                  results.append(P)
    87                                                       else:
    88                                                           indices = P[:, 1] >= p_thresh
    89                                                           output_sizes.append(indices.count_nonzero().item())
    90                                                           indices_list.extend(indices.cpu())
    91                                                           results.append(P[indices])
    92      1000        574.0      0.6      0.0              if loci_per_chunk and (
    93                                                           index % loci_per_chunk == 0 or index == mt_count
    94                                                       ):
    95                                                           mt_site_name_chunk = mt_site_names[last_index:index]
    96                                                           last_index = index
    97                                                           if p_thresh is None:
    98                                                               mt_sites = mt_site_name_chunk.repeat(gt_count)
    99                                                               gt_sites = numpy.tile(gt_site_names, len(results))
   100                                                           else:
   101                                                               mt_sites = mt_site_name_chunk.repeat(output_sizes)
   102                                                               mask = numpy.array(indices_list, dtype=bool)
   103                                                               gt_sites = gt_site_names.repeat(len(results))[mask]
   104                                                           index_chunk = [gt_sites, mt_sites]
   105
   106                                                           file_name = str(logger.current_count + 1) + '.csv'
   107                                                           file_path = os.path.join(output_dir, file_name)
   108                                                           out = pandas.DataFrame(
   109                                                               torch.cat(results),
   110                                                               index=index_chunk,
   111                                                               columns=columns,
   112                                                           ).astype(float)
   113                                                           logger.count(
   114                                                               'Saving part {i}/{0}:',
   115                                                               chunk_count,
   116                                                           )
   117                                                           pool.apply_async(
   118                                                               save_dataframe_part,
   119                                                               (out, file_path, logger.current_count),
   120                                                               dict(logger),
   121                                                           )
   122                                                           results.clear()
   123                                                           output_sizes.clear()
   124                                                           indices_list.clear()
   125
   126         1         39.2     39.2      0.0          logger.time('Looped over methylation loci in {l} seconds')
   127         1         16.8     16.8      0.0          logger.time('Calculated regression_full in {t} seconds')
   128
   129         1          0.4      0.4      0.0          if loci_per_chunk:
   130                                                       logger.time('Waiting for chunks to save...')
   131                                                       pool.close()
   132                                                       pool.join()
   133                                                       logger.time('Finished waiting for chunks to save in {l} seconds')
   134                                                       return
   135
   136         1         12.2     12.2      0.0      logger.start_timer('info', 'Generating dataframe from results...')
   137         1          0.3      0.3      0.0      if p_thresh is None:
   138         1       3726.5   3726.5      0.0          mt_sites = mt_site_names.repeat(gt_count)
   139         1       2045.5   2045.5      0.0          gt_sites = numpy.tile(gt_site_names, len(results))
   140                                               else:
   141                                                   mt_sites = mt_site_names.repeat(output_sizes)
   142                                                   mask = numpy.array(indices_list, dtype=bool)
   143                                                   gt_sites = gt_site_names.repeat(len(results))[mask]
   144         1          0.6      0.6      0.0      index_chunk = [gt_sites, mt_sites]
   145         1         41.2     41.2      0.0      logger.time('Finished creating indices in {l} seconds')
   146         1   15977555.2 15977555.2     27.7      out = pandas.DataFrame(
   147         1        355.2    355.2      0.0          torch.cat(results),
   148         1          0.4      0.4      0.0          index=index_chunk,
   149         1          0.2      0.2      0.0          columns=columns,
   150         1   39625769.8 39625769.8     68.7      ).astype(float)
   151         1         69.6     69.6      0.0      logger.time('Finished creating preliminary dataframe in {l} seconds')
   152         1         26.9     26.9      0.0      logger.time('Created output dataframe in {t} total seconds')
   153         1          0.3      0.3      0.0      return out
```
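In this profile, almost all of the time (27.7% + 68.7%) goes to the two lines that build the output DataFrame, not to the regression. One plausible cause, offered as a hypothesis rather than a measured fact, is that (depending on the pandas version) the DataFrame constructor does not treat a torch tensor as an array and falls back to iterating it element by element, leaving object columns that `.astype(float)` then converts value by value. The sketch below assumes the concatenated results are first moved to NumPy, for example with `torch.cat(results).cpu().numpy()`, and builds the same two-level index as the profiled code; `results_to_dataframe` is a hypothetical helper, not part of tecpg.

```python
import numpy as np
import pandas as pd


def results_to_dataframe(
    P: np.ndarray,
    gt_site_names: np.ndarray,
    mt_site_names: np.ndarray,
    columns: list,
) -> pd.DataFrame:
    """Build the unfiltered output table from a NumPy array of results.

    P is assumed to be torch.cat(results).cpu().numpy(): one block of
    len(gt_site_names) rows per methylation locus, in loop order.
    """
    gt_count = len(gt_site_names)
    # Same layout as the profiled code: each mt site repeated once per gt
    # site, and the full list of gt sites tiled once per mt site.
    mt_sites = np.repeat(mt_site_names, gt_count)
    gt_sites = np.tile(gt_site_names, len(mt_site_names))
    index = pd.MultiIndex.from_arrays([gt_sites, mt_sites])
    # Convert the dtype once on the NumPy side instead of a per-element astype.
    return pd.DataFrame(
        np.asarray(P, dtype=np.float64), index=index, columns=columns
    )
```

With a float32 array already on the CPU, the constructor performs a single vectorized copy, so the cost should scale with the number of values rather than with Python-level iteration.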
kordk commented 1 year ago

By "The size of the input gives a variety of different analysis results", do you mean differences in performance (e.g., time to completion) and resources (e.g., memory or output file size) or the results of the analyses themselves (e.g., statistic and p-values)?

liamgd commented 1 year ago

By that, I mean that the proportion of time spent on each operation changes a lot depending on the sample count and loci counts. For example, smaller inputs spend relatively more time setting up the CUDA device and kernels, whereas a large input with strict p-value filtration spends relatively more time creating the filtration indices and mask.

liamgd commented 1 year ago

Tests of regression_full.py with dummy data: 300 samples, 10,000 methylation loci, 10,000 gene expression loci, no chunking, and a p-threshold of 0.4 (which filters out all results, so none are saved). This input is about 1/202.195 the size of the GTP dataset.

| Device | Time (s) |
|---|---|
| CPU (i7-7700K, 4 threads) | 120.8198 |
| GPU (RTX 2070 Super) | 23.699 |

For small inputs, the CPU tends to outperform the GPU, but past a certain input size the GPU becomes much faster. I suspect that beyond this break-even point, the GPU's advantage over the CPU keeps growing as the input size increases.

This test measured computation only, not saving, because the p-threshold of 0.4 is higher than every mt p-value; the maximum mt p-value seems to be around 0.39.
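A note on the values being thresholded: in the profiled code, `P = torch.exp(log_prob(T))` evaluates the Student's t probability density at each t-statistic rather than a tail probability. That density peaks at t = 0 just below 1/sqrt(2π) ≈ 0.399, which is consistent with a threshold of 0.4 removing every result. The stdlib-only sketch below contrasts the density with a two-sided p-value, 2·P(T ≥ |t|), here computed by Simpson's rule integration of the density; the function names are hypothetical, and in practice a library routine such as `scipy.stats.t.sf` would be the usual choice.

```python
import math


def t_pdf(t: float, dof: float) -> float:
    """Student's t probability density (what exp(log_prob(t)) returns)."""
    log_norm = (
        math.lgamma((dof + 1) / 2)
        - math.lgamma(dof / 2)
        - 0.5 * math.log(dof * math.pi)
    )
    return math.exp(log_norm - (dof + 1) / 2 * math.log1p(t * t / dof))


def t_two_sided_p(t: float, dof: float, steps: int = 10_000) -> float:
    """Two-sided p-value P(|T| >= |t|) via Simpson's rule on [0, |t|].

    Uses symmetry: P(|T| >= |t|) = 1 - 2 * integral_0^|t| pdf(x) dx.
    Loses relative precision for very small p-values; steps must be even.
    """
    upper = abs(t)
    if upper == 0:
        return 1.0
    h = upper / steps
    total = t_pdf(0.0, dof) + t_pdf(upper, dof)
    for i in range(1, steps):
        # Simpson weights: 4 for odd interior points, 2 for even ones.
        total += (4 if i % 2 else 2) * t_pdf(i * h, dof)
    central = total * h / 3
    return max(0.0, 1.0 - 2.0 * central)
```

For example, with 296 degrees of freedom (300 samples and an assumed 2 covariates), `t_pdf(0.0, 296)` is about 0.3986, while `t_two_sided_p(0.0, 296)` is 1.0: a t-statistic of zero has the highest density but the least significant p-value.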

kordk commented 1 year ago

Thanks for the clarification. Once we've stabilized the code for the mlr and evaluated the reproducibility of the Kennedy 2018 analysis, we'll evaluate the performance. This work will include an evaluation of the scaling performance (i.e., how quickly the analyses complete with different numbers of samples and loci).

Is the new mlr code now the default?

liamgd commented 1 year ago

It is not yet the default as it does not have region mapping. After that is complete, it will replace tecpg run mlr.

liamgd commented 1 year ago

Now that output inclusion is controlled in regression_full (as of the last few commits), it has all of the major functionality of regression_single and can replace the tecpg run mlr command. The resulting minor changes are:

  1. The tecpg run mlr-full command is now tecpg run mlr
  2. The tecpg run mlr command is now tecpg run mlr-single
  3. Instead of using --no-est, --no-err, --no-t, and --no-p for controlling the type of regression results to be included in the output of tecpg run mlr, use --p-only or -P to only include p-values in the output. Otherwise, all regression result types will be included.
  4. Instead of using --regressions-per-chunk or -r to control the number of regressions per chunk, use -l or --loci-per-chunk to control the number of methylation loci included per chunk. For each methylation locus, all gene expression loci will be compared. For $l$ loci per chunk with $g$ total gene expression loci, $l \times g$ regressions will run. This is because regression_full operates on each methylation locus with all of the gene expression loci in parallel.
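The relationship between --loci-per-chunk and the work done per chunk can be sketched as below. The helper name `chunk_plan` is hypothetical; the chunk count mirrors the `math.ceil(len(M) / loci_per_chunk)` expression in the profiled regression_full.

```python
import math


def chunk_plan(mt_count: int, gt_count: int, loci_per_chunk: int) -> tuple:
    """Return (chunk_count, regressions_per_full_chunk, total_regressions)."""
    # Each chunk covers up to loci_per_chunk methylation loci;
    # the last chunk may be smaller.
    chunk_count = math.ceil(mt_count / loci_per_chunk)
    # Every methylation locus is regressed against every gene expression locus.
    regressions_per_full_chunk = loci_per_chunk * gt_count
    total_regressions = mt_count * gt_count
    return chunk_count, regressions_per_full_chunk, total_regressions
```

For example, `chunk_plan(10_000, 10_000, 1_000)` gives `(10, 10_000_000, 100_000_000)`: ten output files, each holding ten million regressions.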
liamgd commented 1 year ago

Regression_full is now the default as of 986619e.

kordk commented 1 year ago

Excellent!

kordk commented 1 year ago

Given that users will have different GPU (and CPU) environments, and that chunking will be essential for users to run the analyses, can we provide guidelines for selecting the chunk size based on any one or a combination of characteristics (e.g., number of samples, number of methylation loci, number of gx loci)? We don't need to be precise, just a general suggestion that gives users a place to start and adapt to their own dataset.

liamgd commented 1 year ago

Yes. The most important purpose of chunking is to avoid running out of memory. On CUDA GPUs, torch raises RuntimeError: CUDA out of memory if each chunk is too large.

One potential difficulty is that memory, on either the GPU or the CPU, is also used by other programs, so the memory available to torch could fluctuate during computation. To get around this, the user could input a target memory usage for tecpg, which would allow a lot of control when memory is limited and shared with other programs. Another difficulty is that memory usage can be very unpredictable with region filtration. For example, a cis analysis with a small window can use larger chunks before running into memory limits, because more values are filtered away. This is hard to predict, and the same goes for a p-value threshold.

Let $b$ be the bytes per value of the dtype (for the defaults, torch.float32 and torch.int, this is 4 bytes each), $s$ the number of samples, $c$ the number of covariates, $m$ the number of methylation loci, and $g$ the number of gene expression loci. Constants (excerpt from regression_full; each comment gives the bytes allocated or freed):

```python
dft_sqrt = torch.tensor(df, device=device, dtype=dtype).sqrt()  # b
# The four chromosome/position tensors exist only with region filtration.
G_chrom_t = torch.tensor(G_chrom, device=device, dtype=torch.int)  # g * b
G_pos_t = torch.tensor(G_pos, device=device, dtype=torch.int)  # g * b
M_chrom_t = torch.tensor(M_chrom, device=device, dtype=torch.int)  # m * b
M_pos_t = torch.tensor(M_pos, device=device, dtype=torch.int)  # m * b
Ct: torch.Tensor = torch.tensor(
    C.to_numpy(), device=device, dtype=dtype
).repeat(mt_count, 1, 1)  # c * m * s * b
Mt: torch.Tensor = torch.tensor(
    M.to_numpy(), device=device, dtype=dtype
).unsqueeze(2)  # m * s * b
ones = torch.ones((mt_count, nrows, 1), device=device, dtype=dtype)  # m * s * b
X: torch.Tensor = torch.cat((ones, Mt, Ct), 2)  # (2 + c) * m * s * b
del Mt, Ct, ones  # -(2 + c) * m * s * b
Xt = X.mT  # (2 + c) * m * s * b
XtXi = Xt.bmm(X).inverse()  # (2 + c) ** 2 * m * b
XtXi_diag_sqrt = torch.diagonal(XtXi, dim1=1, dim2=2).sqrt()  # (2 + c) * m * b
XtXi_Xt = XtXi.bmm(Xt)  # (2 + c) * m * s * b
del Xt, XtXi  # -(2 + c) * m * b * (s + 2 + c)
```

Approximate bytes allocated for the constants: $b+(2+c)mb+2msb(c+2)$. With region filtration, add $2b(m+g)$. This is extremely accurate.
Testing this formula against the memory actually allocated for the constants with the GTP dataset gave an error of 0.0674825064989%.

During the regressions for each gene expression locus, some tensors are garbage collected and others are kept so they can be output later. The peak torch memory usage during computation occurs either when the results list is concatenated before saving or when the scalars tensor is calculated, depending on the other variables, especially the number of samples. Let $l$ be the number of loci per chunk and $f\in [0, 1]$ the portion of the tensors that remains after p-value and region filtration.

If the peak memory usage occurs after scalars is created:

```python
Y  # s * b
if region_filtration:
    region_indices  # m * l
B  # f * m * b * (1 if p_only else l)
E  # f * m * s * b
scalars  # f * m * b
P  # f * m * b * (l - 1) * (2 + c if full_output else 1)
if not p_only:
    T  # f * m * b * (l - 1) * (2 + c if full_output else 1)
    S  # f * m * b * (l - 1) * (2 + c if full_output else 1)
if p_filtration:
    p_indices  # f * m * (l - 1)
```

This separates into two parts: chunk constants, $K$, and bytes per locus, $L$. In total, with target memory $T$ and non-chunk constants $C$: $T=C+K+L\times (l-1)$, so $T-C-K=L\times (l-1)$, $l-1=\frac{T-C-K}{L}$, and $l=\frac{T-C-K}{L}+1$.

If the peak memory usage occurs when the results list is concatenated:

```python
P  # 2 * f * m * b * l * (2 + c if full_output else 1)
if not p_only:
    B  # 2 * f * m * b * l * (2 + c if full_output else 1)
    T  # 2 * f * m * b * l * (2 + c if full_output else 1)
    S  # 2 * f * m * b * l * (2 + c if full_output else 1)
```

The factor of $2$ accounts for the redundant copy that exists while the results list is concatenated. In total, with target memory $T$, constants $C$, and bytes per locus $L$: $T=C+L\times l$, so $T-C=L\times l$ and $l=\frac{T-C}{L}$.
Final functions:

```python
def estimate_loci_per_chunk_e_peak(
    target_bytes: int,
    samples: int,
    mt_count: int,
    gt_count: int,
    covar_count: int = 2,
    datum_bytes: int = 4,
    filtration: float = 1,
    full_output: bool = False,
    p_only: bool = True,
    p_filtration: bool = False,
    region_filtration: bool = False,
) -> float:
    """Loci per chunk when the peak occurs after scalars is created."""
    constants_bytes = estimate_constants_bytes(
        samples,
        mt_count,
        gt_count,
        covar_count,
        datum_bytes,
        region_filtration,
    )
    # Chunk constants: E, scalars, and the part of B that does not scale with l
    chunk_constants = (
        filtration * mt_count * samples * datum_bytes
        + 2 * filtration * mt_count * datum_bytes
    )
    if region_filtration:
        chunk_constants += mt_count
    # Bytes per locus: P, plus B, T, and S unless p_only
    locus_bytes = filtration * mt_count * datum_bytes
    if not p_only:
        locus_bytes *= 4
    if full_output:
        locus_bytes *= 2 + covar_count
    if region_filtration:
        locus_bytes += mt_count
    if p_filtration:
        locus_bytes += filtration * mt_count
    e_loci_per_chunk = (
        target_bytes - constants_bytes - chunk_constants
    ) / locus_bytes + 1
    return e_loci_per_chunk


def estimate_loci_per_chunk_results_peak(
    target_bytes: int,
    samples: int,
    mt_count: int,
    gt_count: int,
    covar_count: int = 2,
    datum_bytes: int = 4,
    filtration: float = 1,
    full_output: bool = False,
    p_only: bool = True,
    region_filtration: bool = False,
) -> float:
    """Loci per chunk when the peak occurs while concatenating results."""
    constants_bytes = estimate_constants_bytes(
        samples,
        mt_count,
        gt_count,
        covar_count,
        datum_bytes,
        region_filtration,
    )
    # The factor of 2 covers the results list plus its concatenated copy
    locus_bytes = 2 * filtration * mt_count * datum_bytes
    if not p_only:
        locus_bytes *= 4
    if full_output:
        locus_bytes *= 2 + covar_count
    results_loci_per_chunk = (target_bytes - constants_bytes) / locus_bytes
    return results_loci_per_chunk


def estimate_constants_bytes(
    samples: int,
    mt_count: int,
    gt_count: int,
    covar_count: int = 2,
    datum_bytes: int = 4,
    region_filtration: bool = False,
) -> int:
    """Bytes allocated for the constants: b + (2 + c)mb + 2msb(c + 2)."""
    constants_bytes = (
        datum_bytes
        + (2 + covar_count) * mt_count * datum_bytes
        + 2 * mt_count * samples * datum_bytes * (covar_count + 2)
    )
    if region_filtration:
        # Chromosome and position tensors for M and G: 2b(m + g)
        constants_bytes += 2 * datum_bytes * (mt_count + gt_count)
    return constants_bytes
```
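The two functions above each solve for one peak. Because either peak can dominate depending on the input, one conservative option (an assumption here, not necessarily what tecpg does) is to take the smaller of the two estimates and round it down, so that both peaks stay under the target. The compact sketch below restates both closed forms directly in terms of byte counts; `safe_loci_per_chunk` and its parameter names are hypothetical.

```python
import math


def safe_loci_per_chunk(
    target_bytes: float,
    constants_bytes: float,
    chunk_constants_bytes: float,
    e_locus_bytes: float,
    results_locus_bytes: float,
) -> int:
    """Largest whole number of loci per chunk that keeps both peaks under target."""
    # Peak after scalars is created: target = constants + chunk constants + L * (l - 1)
    e_peak = (
        target_bytes - constants_bytes - chunk_constants_bytes
    ) / e_locus_bytes + 1
    # Peak while concatenating results: target = constants + L * l
    results_peak = (target_bytes - constants_bytes) / results_locus_bytes
    loci = math.floor(min(e_peak, results_peak))
    if loci < 1:
        raise ValueError('Target memory is too small for even one locus per chunk')
    return loci
```

For example, with 300 samples, 10,000 methylation loci, 2 covariates, and 4-byte values, the constants formula gives 96,160,004 bytes. In p-value-only mode with no filtration, the chunk constants are 12,080,000 bytes and the per-locus costs are 40,000 bytes (scalars peak) and 80,000 bytes (results peak), so `safe_loci_per_chunk(2**30, 96_160_004, 12_080_000, 40_000, 80_000)` returns 12219, limited by the results peak.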

The limitations of this algorithm are that the user must provide the filtration coefficient and that it does not account for CPU memory used outside of torch (such as by numpy or pandas).

liamgd commented 1 year ago

Run tecpg chunks to get the maximum loci per chunk for a given target torch memory usage (default: 80% of total memory). Options:

  1. Use --filtration to specify what portion of the data remains after region and p-value filtration, or use -r false -p false to specify that no region or p-value filtration will occur.
  2. Use -s [samples] -m [mt_count] -g [gt_count] -c 2 to estimate for that input size, or omit these options to use the size of the data in the current working directory.
  3. Use -f [true/false] and -P [true/false] to choose which output modes are enabled (full output and p-value only, respectively), or omit these options to show estimates for all combinations of output modes.
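The default target of 80% of total memory could be derived roughly as in the sketch below. This is an assumption about how such a default might be computed, not tecpg's actual code; `default_target_bytes` is a hypothetical name, and `torch.cuda.mem_get_info()` returns a `(free, total)` pair in bytes for the current CUDA device. Basing the target on total rather than free memory matches the stated default, but other programs may already hold part of that memory.

```python
from typing import Optional


def default_target_bytes(
    fraction: float = 0.8, total_bytes: Optional[int] = None
) -> int:
    """Return fraction * total device memory, in bytes."""
    if not 0 < fraction <= 1:
        raise ValueError('fraction must be in (0, 1]')
    if total_bytes is None:
        # Imported lazily so the arithmetic can be used without torch installed.
        import torch

        if not torch.cuda.is_available():
            raise RuntimeError('No CUDA device; pass total_bytes explicitly')
        _free, total_bytes = torch.cuda.mem_get_info()
    return int(fraction * total_bytes)
```

For an 8 GiB card, `default_target_bytes(total_bytes=8 * 2**30)` gives about 6.4 GiB, which could then be passed as the target to the estimates above.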