I think the loop over permutations in the `CPI.predict` function could be avoided by stacking the permuted copies into one larger array, predicting on the whole batch in a single call, and reshaping the output.
Old version:
```python
y_pred_list = []
for _ in range(B):
    # X_permuted (built for this iteration) has shape (N, P)
    y_pred_list.append(model.predict(X_permuted))
y_pred_perm = np.array(y_pred_list)  # shape (B, N)
```
New version:
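```python
# X_permuted_all stacks the B permuted copies block by block: shape (B * N, P)
y_pred_perm = model.predict(X_permuted_all)  # shape (B * N,)
y_pred_perm = y_pred_perm.reshape(B, N)  # shape (B, N), as before
```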
A benchmark on my laptop with `n_permutations = 100` showed the batched version to be about 30% faster.
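A minimal sketch of why the reshape is safe, assuming the B copies are stacked block by block (permutation `b` occupies rows `b*N` to `(b+1)*N - 1`). The `LinearModel` class and the helpers `permuted_copies`, `predict_loop` and `predict_batched` are hypothetical; the plain shuffle of one column stands in for whatever conditional permutation `CPI` actually performs. The check only shows that one batched `predict` call followed by `reshape(B, N)` reproduces the per-permutation loop.

```python
import numpy as np


class LinearModel:
    """Hypothetical stand-in for a fitted estimator exposing .predict."""

    def __init__(self, coef):
        self.coef = np.asarray(coef)

    def predict(self, X):
        return X @ self.coef  # shape (n_rows,)


def permuted_copies(X, j, B, rng):
    """B copies of X, each with column j shuffled independently: shape (B, N, P).

    Assumption: a plain marginal shuffle, not CPI's conditional permutation.
    """
    X_perm = np.repeat(X[np.newaxis], B, axis=0)  # (B, N, P), a writable copy
    for b in range(B):
        X_perm[b, :, j] = rng.permutation(X[:, j])
    return X_perm


def predict_loop(model, X_perm):
    # Old approach: one predict call per permutation.
    return np.array([model.predict(Xb) for Xb in X_perm])  # (B, N)


def predict_batched(model, X_perm):
    # New approach: stack the B copies into (B * N, P) and predict once.
    B, N, P = X_perm.shape
    y = model.predict(X_perm.reshape(B * N, P))  # (B * N,)
    return y.reshape(B, N)  # row b holds the predictions for permutation b


rng = np.random.default_rng(0)
N, P, B = 50, 4, 100
X = rng.normal(size=(N, P))
model = LinearModel(rng.normal(size=P))
X_perm = permuted_copies(X, j=2, B=B, rng=rng)

y_loop = predict_loop(model, X_perm)
y_batch = predict_batched(model, X_perm)
print(y_batch.shape)  # (100, 50)
print(np.allclose(y_loop, y_batch))  # True
```

The trade-off is memory: the stacked array holds B * N * P values at once, so for large N or B it may be necessary to predict in chunks of permutations instead of all of them in one call.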