nipreps / dmriprep

dMRIPrep is a robust and easy-to-use pipeline for preprocessing diverse dMRI data. The transparent workflow dispenses with manual intervention, thereby ensuring the reproducibility of the results.
https://www.nipreps.org/dmriprep
Apache License 2.0

What is the best approach for predicting an output image using the fitted spherical harmonic / tensor models in Dipy? #55

Closed: dPys closed this issue 4 years ago

dPys commented 4 years ago

Basically, we need to figure out what will replace the commented-out SHORE prediction lines below so that other models in Dipy can be accommodated.

This is what I have so far in the refactored version of this interface, up to the point where the actual signal prediction is calculated:

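        import os.path as op
        import numpy as np
        import nibabel as nib
        from dipy.core.gradients import gradient_table_from_bvals_bvecs
        # quick_load_images and _nonoverlapping_qspace_samples are helpers
        # defined elsewhere in the interface module (not shown here)
        pred_vec = self.inputs.bvec_to_predict
        pred_val = self.inputs.bval_to_predict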

        # Load the mask image:
        mask_img = nib.load(self.inputs.b0_mask)
        # get_data() is deprecated (removed in nibabel 5); use get_fdata()
        mask_array = mask_img.get_fdata() > 1e-6
        all_images = self.inputs.aligned_dwi_files

        # Columns 0-2 hold the RAS bvec and column 3 the bval; prepend a
        # b=0 row for the median b0 image
        ras_b_mat = np.genfromtxt(self.inputs.aligned_vectors, delimiter='\t')
        all_bvecs = np.vstack([np.zeros(3), ras_b_mat[:, 0:3]])
        all_bvals = np.concatenate([[0.], ras_b_mat[:, 3]])

        # Which sample points are too close to the one we want to predict?
        training_mask = _nonoverlapping_qspace_samples(
            pred_val, pred_vec, all_bvals, all_bvecs, self.inputs.minimal_q_distance)
        training_indices = np.flatnonzero(training_mask[1:])
        training_image_paths = [self.inputs.b0_median] + [
            all_images[idx] for idx in training_indices]
        training_bvecs = all_bvecs[training_mask]
        training_bvals = all_bvals[training_mask]
        print("Training with %d of %d" % (training_mask.sum(), len(training_mask)))

        # Load training data and fit the model
        training_data = quick_load_images(training_image_paths)

        # Build gradient table object
        training_gtab = gradient_table_from_bvals_bvecs(
            training_bvals, training_bvecs, b0_threshold=self.inputs.b0_threshold)

        # Check shelledness: more than two unique b-values (b0 plus at least
        # two non-zero b-values) is treated as multi-shell data
        is_shelled = len(np.unique(training_gtab.bvals)) > 2

        if is_shelled:
            from dipy.reconst.shore import ShoreModel
            radial_order = 6
            zeta = 700
            lambdaN = 1e-8
            lambdaL = 1e-8
            estimator = ShoreModel(training_gtab, radial_order=radial_order,
                                   zeta=zeta, lambdaN=lambdaN, lambdaL=lambdaL)
            estimator_fit = estimator.fit(training_data, mask=mask_array)
        else:
            from dipy.reconst.dti import TensorModel, fractional_anisotropy, mean_diffusivity
            from dipy.reconst.csdeconv import recursive_response, ConstrainedSphericalDeconvModel
            estimator_ten = TensorModel(training_gtab)
            estimator_ten_fit = estimator_ten.fit(training_data, mask=mask_array)
            FA = fractional_anisotropy(estimator_ten_fit.evals)
            MD = mean_diffusivity(estimator_ten_fit.evals)
            # Heuristic white-matter mask used to estimate the CSD response function
            wm_mask = np.logical_or(FA >= 0.4, np.logical_and(FA >= 0.15, MD >= 0.0011))
            response = recursive_response(training_gtab, training_data, mask=wm_mask)
            estimator_csd = ConstrainedSphericalDeconvModel(training_gtab, response, sh_order=6)
            estimator_csd_fit = estimator_csd.fit(training_data, mask=mask_array)
            # weighted mean of csd predicted array and tensor predicted array?

        # Get the vector for the desired coordinate
        prediction_bvecs = np.tile(pred_vec, (10, 1))
        prediction_bvals = np.ones(10) * pred_val
        prediction_bvals[9] = 0  # prevent warning
        prediction_gtab = gradient_table_from_bvals_bvecs(
            prediction_bvals, prediction_bvecs, b0_threshold=self.inputs.b0_threshold)

        # Calculate the signal prediction, reshape to 3D and save
        # prediction_shore = brainsuite_shore_basis(shore_model.radial_order, shore_model.zeta,
        #                                           prediction_gtab, shore_model.tau)
        # prediction_dir = prediction_shore[0]
        # shore_array = estimator_fit._shore_coef[mask_array]
        # output_data = np.zeros(mask_array.shape)
        # output_data[mask_array] = np.dot(shore_array, prediction_dir)

        prediction_file = op.join(
            runtime.cwd,
            "predicted_b%d_%.2f_%.2f_%.2f.nii.gz" % (
                (pred_val,) + tuple(np.round(pred_vec, decimals=2))))
        nib.Nifti1Image(output_data, mask_img.affine, mask_img.header
                       ).to_filename(prediction_file)
        self._results['predicted_image'] = prediction_file
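The helper _nonoverlapping_qspace_samples is called in the snippet but not shown in this issue. As a rough sketch of what such a leave-out rule can look like, the mask below keeps only samples whose q-vector lies farther than the cutoff from both the target q-vector and its antipode (the diffusion signal is antipodally symmetric). The function name, the sqrt(b) approximation for q, and expressing minimal_q_distance as a percentage of the largest q are all assumptions here, not the dmriprep implementation:

```python
import numpy as np


def nonoverlapping_qspace_samples(pred_bval, pred_bvec, all_bvals, all_bvecs, cutoff):
    """Hypothetical stand-in for _nonoverlapping_qspace_samples.

    Returns a boolean mask that is True for samples whose q-space position
    is farther than `cutoff` from the target, checking both the target
    direction and its antipode. Assumes q ~ sqrt(b) and that `cutoff` is
    given in percent of the largest q.
    """
    all_bvals = np.asarray(all_bvals, dtype=float)
    all_bvecs = np.asarray(all_bvecs, dtype=float)
    pred_bvec = np.asarray(pred_bvec, dtype=float)

    # q is proportional to sqrt(b) when the diffusion timing is fixed
    all_q = np.sqrt(all_bvals)
    pred_q = np.sqrt(pred_bval)
    max_q = max(all_q.max(), pred_q)

    # Scale q-vectors so that distances are in percent of the largest q
    qvecs = all_bvecs * (all_q / max_q * 100)[:, np.newaxis]
    pred_qvec = pred_bvec * (pred_q / max_q * 100)

    # Reject samples close to the target direction or to its antipode
    dist_pos = np.linalg.norm(qvecs - pred_qvec, axis=1)
    dist_neg = np.linalg.norm(qvecs + pred_qvec, axis=1)
    return (dist_pos > cutoff) & (dist_neg > cutoff)


# One b0 plus three b=1000 directions; predict along +x
bvals = np.array([0., 1000., 1000., 1000.])
bvecs = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [-1, 0, 0]], dtype=float)
mask = nonoverlapping_qspace_samples(1000., [1., 0., 0.], bvals, bvecs, cutoff=10)
print(mask.tolist())  # [True, False, True, False]
```

Both +x and -x are excluded, while the b0 row stays in the training set, which fits how the snippet prepends the median b0 and then indexes the DWI files with training_mask[1:].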

@arokem @oesteban @mattcieslak

Once this part is addressed, we should be very, very close to having an HMC (head-motion correction) interface for dmriprep.
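The shelledness check in the snippet (len(np.unique(training_gtab.bvals)) > 2) counts raw unique b-values, so scanner jitter such as 995, 1000 and 1005 s/mm^2 in a single-shell scheme would be read as multi-shell data and route the fit to SHORE. A minimal numpy sketch of a more tolerant count, assuming b-values within one shell differ by well under 50 s/mm^2 so that rounding to the nearest 100 groups them (the helper name count_shells and the rounding step are hypothetical):

```python
import numpy as np


def count_shells(bvals, b0_threshold=50, round_to=100):
    """Count distinct non-zero b-value shells.

    Assumption: b-values in the same shell differ only by small jitter,
    so rounding to the nearest `round_to` s/mm^2 groups them together.
    """
    bvals = np.asarray(bvals, dtype=float)
    dw = bvals[bvals > b0_threshold]          # drop b0 volumes
    rounded = np.round(dw / round_to) * round_to
    return len(np.unique(rounded))


multi = [0, 5, 995, 1000, 1005, 2000, 2010]
single = [0, 995, 1000, 1005]
print(len(np.unique(multi)) > 2)   # True
print(count_shells(multi) > 1)     # True (shells at 1000 and 2000)
print(len(np.unique(single)) > 2)  # True, although the data are single-shell
print(count_shells(single) > 1)    # False
```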

dPys commented 4 years ago

Figured this out: call predict() on the fitted estimator, passing the prediction gtab!