mind-inria / hidimstat

HiDimStat: High-dimensional statistical inference tool for Python
https://mind-inria.github.io/hidimstat
BSD 2-Clause "Simplified" License

Accelerate CPI by using batch prediction and NumPy array operations instead of a for loop #30

Closed: jpaillard closed this 2 weeks ago

jpaillard commented 4 weeks ago

I think the loop over permutations in the CPI.predict function could be avoided by stacking all permuted copies of X into one larger array, predicting on the whole batch in a single call and reshaping the output.

Old version:

y_pred_list = []
for _ in range(B):
    # X_permuted: the current permuted copy of X, shape (N, P)
    y_pred_list.append(model.predict(X_permuted))
y_pred_perm = np.array(y_pred_list)  # shape (B, N)

New version:

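# X_permuted_all: the B permuted copies stacked one permutation after another, shape (N*B, P)
y_pred_perm = model.predict(X_permuted_all)  # shape (N*B,)
y_pred_perm = y_pred_perm.reshape(B, N)  # reshape returns a new array, so assign it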
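The two versions can be checked against each other in a small self-contained sketch. It assumes a hypothetical `LinearModel` with a scikit-learn-style `predict` method and, for simplicity, shuffles a single column `j` independently in each copy instead of performing the conditional permutation CPI is based on. What it shows is that the batched call reproduces the loop's `(B, N)` output only if the copies are stacked permutation by permutation, so that rows `b*N:(b+1)*N` of `X_permuted_all` belong to permutation `b`.

```python
import numpy as np


class LinearModel:
    """Hypothetical stand-in for a fitted estimator with a predict(X) method."""

    def __init__(self, coef):
        self.coef = np.asarray(coef)

    def predict(self, X):
        return X @ self.coef


def permuted_copies(X, j, B, rng):
    """Return B copies of X with column j shuffled independently; shape (B, N, P).

    Plain column shuffle, used here as a simplified stand-in for CPI's
    conditional permutation.
    """
    N = X.shape[0]
    X_perm = np.repeat(X[np.newaxis], B, axis=0)
    idx = np.argsort(rng.random((B, N)), axis=1)  # B random permutations of 0..N-1
    X_perm[:, :, j] = X[idx, j]
    return X_perm


def predict_loop(model, X_perm):
    """Old version: one predict call per permutation."""
    y_pred_list = [model.predict(X_permuted) for X_permuted in X_perm]
    return np.array(y_pred_list)  # shape (B, N)


def predict_batched(model, X_perm):
    """New version: a single predict call on all copies stacked along axis 0."""
    B, N, P = X_perm.shape
    X_permuted_all = X_perm.reshape(B * N, P)  # rows b*N:(b+1)*N come from copy b
    y_pred_perm = model.predict(X_permuted_all)  # shape (N*B,)
    return y_pred_perm.reshape(B, N)


rng = np.random.default_rng(0)
N, P, B, j = 200, 5, 100, 2
X = rng.normal(size=(N, P))
model = LinearModel(rng.normal(size=P))
X_perm = permuted_copies(X, j, B, rng)

assert np.allclose(predict_loop(model, X_perm), predict_batched(model, X_perm))
```

`np.allclose` rather than exact equality is used because a single large matrix product may round slightly differently from B smaller ones.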

Running a benchmark on my laptop with n_permutations = 100, the new version is 30% faster.
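The benchmark script is not part of the issue. A minimal timing harness in the same spirit might look like the following; the linear `predict` function and the random `X_perm` are hypothetical stand-ins for a real estimator and real permuted data, so the measured ratio depends entirely on the model and sizes chosen and is not meant to reproduce the 30% figure.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
N, P, B = 500, 10, 100
coef = rng.normal(size=P)
X_perm = rng.normal(size=(B, N, P))  # stand-in for the B permuted copies of X


def predict(X):
    """Hypothetical stand-in for model.predict (a linear model)."""
    return X @ coef


def old_version():
    # one predict call per permutation, results stacked into shape (B, N)
    return np.array([predict(X_permuted) for X_permuted in X_perm])


def new_version():
    # a single predict call on the (N*B, P) array, reshaped to (B, N)
    return predict(X_perm.reshape(B * N, P)).reshape(B, N)


# best of 5 repeats, each timing 10 calls
t_old = min(timeit.repeat(old_version, number=10, repeat=5))
t_new = min(timeit.repeat(new_version, number=10, repeat=5))
print(f"loop: {t_old:.4f} s, batched: {t_new:.4f} s, ratio: {t_new / t_old:.2f}")
```

Taking the minimum over repeats is a common way to reduce noise from other processes on a laptop.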

bthirion commented 3 weeks ago

A decent win indeed! I guess the arrays are small enough that building one big array does not hurt performance.
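To make the size argument concrete: the stacked array holds N·B·P values, so for example with float64 data, N = 1000 samples, P = 50 features and B = 100 permutations it takes 1000 × 100 × 50 × 8 bytes = 40 MB. If that ever became a problem, a middle ground between the loop and a single call would be to build and predict the permuted copies a few permutations at a time. The sketch below is hypothetical: `predict_in_chunks`, `make_batch` and `batch_size` are illustrative names, not part of hidimstat, and the permutation is a plain column shuffle rather than CPI's conditional one.

```python
import numpy as np


def predict_in_chunks(predict, make_batch, B, batch_size):
    """Predict on B permuted copies, building at most batch_size copies at a time.

    make_batch(n) must return an array of shape (n, N, P); the result has shape (B, N).
    """
    chunks = []
    for start in range(0, B, batch_size):
        n = min(batch_size, B - start)
        X_batch = make_batch(n)  # only n copies are held in memory here
        _, N, P = X_batch.shape
        y = predict(X_batch.reshape(n * N, P))  # one predict call per chunk
        chunks.append(y.reshape(n, N))
    return np.concatenate(chunks, axis=0)


# Illustrative usage with a linear "model" and a plain shuffle of column j.
rng = np.random.default_rng(0)
N, P, B, j = 1000, 50, 100, 3
X = rng.normal(size=(N, P))
coef = rng.normal(size=P)


def make_batch(n):
    X_batch = np.repeat(X[np.newaxis], n, axis=0)  # shape (n, N, P)
    idx = np.argsort(rng.random((n, N)), axis=1)  # n independent permutations
    X_batch[:, :, j] = X[idx, j]
    return X_batch


y_pred_perm = predict_in_chunks(lambda Z: Z @ coef, make_batch, B, batch_size=10)
print(y_pred_perm.shape)  # → (100, 1000)
```

With `batch_size=B` this reduces to the single batched call, and with `batch_size=1` it reduces to the original loop, so the parameter trades peak memory against the number of predict calls.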