angelolab / ark-analysis

Integrated pipeline for multiplexed image analysis
https://ark-analysis.readthedocs.io/en/latest/
MIT License

Convert Pixie to AnnData #1174

Open alex-l-kong opened 2 weeks ago

alex-l-kong commented 2 weeks ago

Relevant background

The Pixie pipeline computes and saves out a lot of extraneous files, most of which correspond to a common set of dimensions (num_cells and num_channels). This has been extremely cumbersome for a few reasons:

  1. It's difficult to keep track of the names of all these files. Additionally, if the naming standard for a file changes, this change needs to propagate through the entire pipeline.
  2. More files = more variables to keep track of and pass around
  3. Fallback protection is difficult to implement because several files need to be checked and possibly updated.
  4. Re-running portions of the pipeline is cumbersome because it requires the user to manually delete files. It's much easier to programmatically access and delete a variable in a shared data structure with AnnData.
  5. Users are unable to easily snapshot previous runs on the same dataset with different parameters.

Given the gradual conversion of the spatial analysis portion of the pipeline to AnnData, it makes sense to think about how AnnData could be used in the Pixie implementation as well.

Design overview

The implementation for Pixie will be different for pixel- and cell-level clustering.

Pixel-level clustering

Because Pixie trains the pixel SOM on a subset of the data, this will require some additional thought.

The full subsetted data can be represented as an AnnData object:

The main challenge with the pixel clustering component is the need to train on a subset, then assign on the full dataset. No matter how efficient AnnData is in loading data, it would be cumbersome to load a full pixel-level dataset into memory.

The SOM weights (from uns) and the normalization coefficients (norm_coeffs, from var) can still be accessed and used to normalize and assign clusters to each FOV's full pixel dataset individually.
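A minimal sketch of the per-FOV normalization step, assuming the coefficients have been read out of var as a pandas Series indexed by channel name (var has one row per channel, so any var column already has that shape). `normalize_fov` and the channel names are hypothetical, not existing ark code; because it handles one FOV's table at a time, the full pixel dataset never has to be loaded into memory.

```python
import pandas as pd


def normalize_fov(fov_pixels: pd.DataFrame, norm_coeffs: pd.Series) -> pd.DataFrame:
    """Divide each channel column of one FOV's pixel table by its normalization coefficient.

    norm_coeffs is indexed by channel name (as a var column would be), so the division
    aligns on labels and does not depend on the column order of the FOV table.
    Non-channel columns (fov, row_index, ...) are passed through unchanged.
    """
    channels = list(norm_coeffs.index)
    missing = set(channels) - set(fov_pixels.columns)
    if missing:
        raise ValueError(f"FOV is missing channels: {sorted(missing)}")

    fov_norm = fov_pixels.copy()
    fov_norm[channels] = fov_norm[channels].div(norm_coeffs, axis=1)
    return fov_norm


# hypothetical example: channel columns deliberately out of order relative to the coefficients
coeffs = pd.Series({"CD45": 2.0, "CD8": 4.0})
fov = pd.DataFrame({"CD8": [4.0, 8.0], "CD45": [2.0, 1.0], "fov": ["fov0", "fov0"]})
fov_norm = normalize_fov(fov, coeffs)
```

Cluster assignment would then run on `fov_norm` FOV by FOV, so only the SOM weights and coefficients need to stay resident.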

The tables of average channel expression per cluster will require some thought; there are two possible places they could go: varm or uns.

Rather than saving a hacky dict of params for use in cell clustering, we can save this Pixie AnnData object and load it into the cell clustering notebook.

Cell-level clustering

Because we train and label the same dataset, AnnData representation is much easier for cell-level data.

As with pixel-level clustering, we need to decide if the average cell_expression_col per cluster + weighted channel average tables should be stored in varm or uns.
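To make the varm-vs-uns trade-off concrete: an average-per-cluster table has one row per cluster, while a varm entry must have one row per variable (n_vars rows). So the table fits uns as-is, but would have to be stored transposed (one column per cluster) to live in varm. A minimal pandas sketch; the helper name and the expression column names are hypothetical placeholders.

```python
import pandas as pd


def average_per_cluster(obs_clusters: pd.Series, expr: pd.DataFrame) -> pd.DataFrame:
    """Mean of each expression column per cluster: rows = clusters, columns = variables."""
    # group by the raw label values so AnnData's string obs_names can't cause index misalignment
    return expr.groupby(obs_clusters.values).mean()


expr = pd.DataFrame({"pc1": [1.0, 3.0, 5.0, 7.0], "pc2": [0.0, 2.0, 4.0, 6.0]})
clusters = pd.Series([1, 1, 2, 3], name="cell_som_cluster")

avg = average_per_cluster(clusters, expr)  # (n_clusters, n_vars): fits uns directly
avg_for_varm = avg.T                       # (n_vars, n_clusters): the shape varm requires
```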

NOTE: if the FlowSOM_Python repo works out for us (https://github.com/saeyslab/FlowSOM_Python), we may be able to delegate a lot of the work there.

Integration with metacluster remapping step/visualization

The visualization component has caused us no shortage of issues in the past, in large part because of inconsistent coloring schemes. AnnData can be used to store information related to the metacluster remapping step to save a lot of hassle.

For now, we can simply access the data we need to pass into the remapping GUI from the AnnData objects. Implementing this will likely be a separate PR in itself.
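As a sketch of what that access could look like, assuming obs carries both a `*_som_cluster` and a `*_meta_cluster` column (the naming used in the mockup below): the SOM-to-meta mapping and per-cluster counts can be derived straight from obs, with no separate mapping file. The helper is hypothetical, and the output is not meant to match the remapping GUI's actual input format.

```python
import pandas as pd


def som_to_meta_table(obs: pd.DataFrame, som_col: str, meta_col: str) -> pd.DataFrame:
    """One row per SOM cluster: its meta cluster and how many observations it contains."""
    grouped = obs.groupby(som_col)
    meta_per_som = grouped[meta_col].unique()

    # every SOM cluster should map to exactly one meta cluster
    conflicting = meta_per_som[meta_per_som.apply(len) > 1]
    if not conflicting.empty:
        raise ValueError(f"SOM clusters mapped to >1 meta cluster: {list(conflicting.index)}")

    table = pd.DataFrame({
        meta_col: meta_per_som.apply(lambda metas: metas[0]),
        "count": grouped.size(),
    })
    table.index.name = som_col
    return table.reset_index()


obs = pd.DataFrame({
    "pixel_som_cluster": [1, 1, 2, 3, 3, 3],
    "pixel_meta_cluster": [1, 1, 1, 2, 2, 2],
})
mapping = som_to_meta_table(obs, "pixel_som_cluster", "pixel_meta_cluster")
```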

Code mockup

An AnnData object can be created like:

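import anndata as ad
import pandas as pd

train_df = pd.read_feather("/path/to/training/data.feather")  # or just pass in an already-loaded DataFrame
# X holds only the channel columns (channels = list of channel names); the remaining
# per-observation metadata (fov, row/column index, ...) goes into obs
anndata_obj = ad.AnnData(X=train_df[channels], obs=train_df.drop(columns=channels))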

For pixel-level training, we'd have to concatenate the data from all the .feather files together. We already do this in the existing Pixie workflow very efficiently.
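A minimal pandas sketch of assembling the training subset, assuming each FOV's pixel table is already loaded (reading the .feather files is omitted) and that a fixed fraction of pixels is sampled per FOV. `build_training_subset` and `subset_proportion` are illustrative names, not the existing Pixie signature. Channel columns become X; everything else becomes obs.

```python
import numpy as np
import pandas as pd


def build_training_subset(fov_tables: dict, channels: list,
                          subset_proportion: float = 0.1, seed: int = 42):
    """Sample each FOV, concatenate, and split into an X matrix and an obs table."""
    pieces = []
    for fov_name, pixels in fov_tables.items():
        # sample a fixed fraction of pixels per FOV, reproducibly
        sub = pixels.sample(frac=subset_proportion, random_state=seed)
        pieces.append(sub.assign(fov=fov_name))

    combined = pd.concat(pieces, ignore_index=True)
    X = combined[channels].to_numpy(dtype=np.float64)
    obs = combined.drop(columns=channels)
    # AnnData expects string obs names
    obs.index = obs.index.astype(str)
    return X, obs
```

The two outputs map onto `ad.AnnData(X=X, obs=obs, var=pd.DataFrame(index=channels))`.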

Here's how the preprocessing functions could look with AnnData:

def preprocess_fov(...):
    ...
    return pixel_mat, pixel_mat_subset # we will NOT be writing out the subsetted data into separate dfs anymore

def create_pixel_matrix(..., pixel_subset_name="pixel_mat_subsetted.h5ad", ...):
    ...
    # initialize a blank AnnData object (0 observations, one variable per channel),
    # or load it if an intermediate file was already saved
    if os.path.exists(os.path.join(base_dir, pixel_output_dir, pixel_subset_name)):
        pixel_train_data = ad.read_h5ad(os.path.join(base_dir, pixel_output_dir, pixel_subset_name))
    else:
        pixel_train_data = ad.AnnData(pd.DataFrame(columns=channels, dtype=np.float64))
    if "channel" not in pixel_train_data.var.columns.values:
        pixel_train_data.var["channel"] = channels

    # FOVs already in the subset ("fov" is absent from obs on a freshly created object)
    fovs_sub = pixel_train_data.obs["fov"].unique() if "fov" in pixel_train_data.obs else []
    # trim the .feather suffix first so the names match the FOV names stored in obs
    fovs_data = io_utils.remove_file_extensions(
        io_utils.list_files(os.path.join(base_dir, data_dir), substrs='.feather')
    )
    # intersect the two fovs lists together (if a FOV appears in one but not the other, regenerate)
    fovs_full = list(set(fovs_sub).intersection(fovs_data))

    # define the list of FOVs for preprocessing
    # NOTE: if an existing FOV is already corrupted, future steps will discard it
    fovs_list = list(set(fovs).difference(set(fovs_full)))

    # if there are no FOVs left to preprocess don't run function
    if len(fovs_list) == 0:
        print("There are no more FOVs to preprocess, skipping")
        return

    # check for missing quant data and add to the list of FOVs for processing
    if "quant_dat_all" not in pixel_train_data.uns:
        pixel_train_data.uns["quant_dat_all"] = pd.DataFrame()
    quant_fov_list = pixel_train_data.uns["quant_dat_all"].columns
    quant_missing = list(set(fovs).difference(set(quant_fov_list)))
    fovs_list = list(set(fovs_list).union(set(quant_missing)))

    # if the process is only partially complete, inform the user of restart
    if len(fovs_list) < len(fovs):
        print("Restarting preprocessing from FOV %s, "
              "%d fovs left to process" % (fovs_list[0], len(fovs_list)))

    ...

    # var has one row per channel, so compare the channels via var_names
    if "norm_vals_pre_rownorm" not in pixel_train_data.var.columns.values or set(pixel_train_data.var_names) != set(channels):
        ...
        # one value per channel; pandas aligns the Series on the channel (var) index
        pixel_train_data.var["norm_vals_pre_rownorm"] = channel_norm_pre_rownorm_df.iloc[0]

    if "pixel_thresh_val" not in pixel_train_data.uns or set(pixel_train_data.var_names) != set(channels):
        ...
        pixel_train_data.uns["pixel_thresh_val"] = pixel_thresh_val

    pixel_train_data.write_h5ad(os.path.join(base_dir, pixel_output_dir, pixel_subset_name))

    ...

    if multiprocess:
        # define the multiprocessing context
        with multiprocessing.get_context('spawn').Pool(batch_size) as fov_data_pool:
            # asynchronously generate and save the pixel matrices per FOV
            # NOTE: fov_data_pool should NOT operate on quant_dat since that is a shared resource
            for fov_batch in [fovs_list[i:(i + batch_size)]
                              for i in range(0, len(fovs_list), batch_size)]:
                fov_data_batch = fov_data_pool.map(fov_data_func, fov_batch)

                # compute the 99.9% quantile values for each FOV
                for pixel_mat_data, pixel_mat_subset in fov_data_batch:
                    ...
                    # quant_dat_all lives in uns, so merge this FOV's quantiles into it there
                    pixel_train_data.uns["quant_dat_all"] = pixel_train_data.uns["quant_dat_all"].merge(
                        quant_dat_fov, how="outer", left_index=True, right_index=True
                    )
                    # AnnData doesn't support item assignment on "X"/"obs": wrap the FOV subset in its
                    # own AnnData and concatenate ("first" keeps the existing var columns and uns)
                    fov_subset = ad.AnnData(X=pixel_mat_subset[channels],
                                            obs=pixel_mat_subset.drop(columns=channels))
                    pixel_train_data = ad.concat([pixel_train_data, fov_subset],
                                                 merge="first", uns_merge="first")
                    # per-FOV row indices repeat across FOVs, so keep obs_names unique
                    pixel_train_data.obs_names_make_unique()
                    pixel_train_data.write_h5ad(os.path.join(base_dir, pixel_output_dir, pixel_subset_name))

    else:
        for fov in fovs_list:
            pixel_mat_data, pixel_mat_subset = fov_data_func(fov)
            ...
            pixel_train_data.uns["quant_dat_all"] = pixel_train_data.uns["quant_dat_all"].merge(
                quant_dat_fov, how="outer", left_index=True, right_index=True
            )
            fov_subset = ad.AnnData(X=pixel_mat_subset[channels],
                                    obs=pixel_mat_subset.drop(columns=channels))
            pixel_train_data = ad.concat([pixel_train_data, fov_subset],
                                         merge="first", uns_merge="first")
            pixel_train_data.obs_names_make_unique()
            pixel_train_data.write_h5ad(os.path.join(base_dir, pixel_output_dir, pixel_subset_name))

    # additional features can be set to pixel_train_data

    # persist for fallback purposes
    pixel_train_data.write_h5ad(os.path.join(base_dir, pixel_output_dir, pixel_subset_name))
    return pixel_train_data
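The restart bookkeeping in `create_pixel_matrix` boils down to set arithmetic on FOV names. A pure-Python sketch of the same rules, assuming the inputs are the requested FOVs, the FOVs already in obs, the FOVs with a data file on disk, and the FOVs already in the quantile table (the columns of uns["quant_dat_all"]); the function name is hypothetical.

```python
def fovs_to_process(requested, in_subset, on_disk, in_quant_table):
    """Return the FOVs (sorted) that still need preprocessing.

    A FOV is treated as done only if it appears both in the saved subset (obs["fov"]) and
    on disk; any requested FOV missing from the quantile table is (re)processed as well.
    """
    done = set(in_subset) & set(on_disk)
    remaining = set(requested) - done
    missing_quant = set(requested) - set(in_quant_table)
    return sorted(remaining | missing_quant)


to_run = fovs_to_process(
    requested=["fov0", "fov1", "fov2", "fov3"],
    in_subset=["fov0", "fov1"],
    on_disk=["fov0", "fov1", "fov2"],
    in_quant_table=["fov0"],
)
```

Sorting is a choice made here so the "Restarting preprocessing from FOV ..." message is deterministic; `list(set(...))` in the mockup has no guaranteed order.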

For integration with the rest of Pixie, we'll need to rethink the classes in cluster_helpers.py.

The bulk of the functionality derives from the base PixieSOMCluster class; the mockup below shows it along with its pixel/cell subclasses and the consensus clustering class:

class PixieSOMCluster(ABC):
    @abstractmethod
    def __init__(self, pixel_train_data: ad.AnnData, train_data_path: str, num_passes: int = 1,
                 xdim: int = 10, ydim: int = 10, lr_start: float = 0.05, lr_end: float = 0.01,
                 seed=42):
        self.pixel_train_data = pixel_train_data
        if "weights" not in pixel_train_data.uns:
            self.pixel_train_data.uns["weights"] = None
        self.train_data_path = train_data_path
        self.num_passes = num_passes
        self.xdim = xdim
        self.ydim = ydim
        self.lr_start = lr_start
        self.lr_end = lr_end
        self.seed = seed

    @abstractmethod
    def normalize_data(self) -> pd.DataFrame:
        """Normalize the data prior to SOM training/assignment (defined per subclass)."""

    def train_som(self, data: pd.DataFrame):
        # pyFlowSOM.som requires data in np.float64, add type cast for safety purposes
        som_weights = som(
            data=data.values.astype(np.float64), xdim=self.xdim, ydim=self.ydim,
            rlen=self.num_passes, alpha_range=(self.lr_start, self.lr_end), seed=self.seed
        )

        # ensure dimensions of weights are flattened
        som_weights = np.reshape(som_weights, (self.xdim * self.ydim, som_weights.shape[-1]))
        self.pixel_train_data.uns["weights"] = pd.DataFrame(som_weights, columns=data.columns.values)

        # persist the trained weights for fallback purposes
        self.pixel_train_data.write_h5ad(self.train_data_path)

    def generate_som_clusters(self, external_data: pd.DataFrame,
                              num_parallel_obs: int = 1000000) -> np.ndarray:
        # ensure batch_size passed is valid
        if num_parallel_obs <= 0:
            raise ValueError("num_parallel_obs specified needs to be greater than 0")

        # subset on just the weights columns prior to SOM cluster mapping
        weights = self.pixel_train_data.uns["weights"]
        weights_cols = weights.columns.values

        # ensure the weights cols are actually contained in external_data
        verify_in_list(
            weights_cols=weights_cols,
            external_data_cols=external_data.columns.values
        )

        # define the batches of cluster labels assigned
        cluster_labels = []

        # work in batches to support large dataframe sizes
        # TODO: possible dynamic computation in order?
        for i in np.arange(0, external_data.shape[0], num_parallel_obs):
            # NOTE: selecting weights_cols also orders the batch's columns the same as the weights;
            # iloc slices by position, so this works for any index (e.g. string obs_names)
            cluster_labels.append(map_data_to_nodes(
                weights.values.astype(np.float64),
                external_data[weights_cols].iloc[i:i + num_parallel_obs].values.astype(np.float64)
            )[0])

        # if no pixels in the image, return empty array
        if not cluster_labels:
            return np.empty(0)
        # concat all the results together and return
        return np.concatenate(cluster_labels)

class PixelSOMCluster(PixieSOMCluster):
    def __init__(self, pixel_train_data: ad.AnnData, train_data_path: str,
                 num_passes: int = 1, xdim: int = 10, ydim: int = 10,
                 lr_start: float = 0.05, lr_end: float = 0.01, seed=42):
        super().__init__(
            pixel_train_data, train_data_path, num_passes, xdim, ydim, lr_start, lr_end, seed
        )

        # define each SOM cluster seen
        self.som_clusters_seen = set()

    def normalize_data(self, external_data: pd.DataFrame) -> pd.DataFrame:
        # per-channel normalization values: a var column, i.e. a Series indexed by channel
        norm_vals = self.pixel_train_data.var["norm_data_post_rownorm"]
        norm_data_cols = norm_vals.index.values

        # verify norm_data_cols actually contained in external_data
        verify_in_list(
            norm_data_cols=norm_data_cols,
            external_data_cols=external_data.columns.values
        )

        # div aligns norm_vals on the column names, so column order in external_data doesn't matter
        external_data_norm = external_data.copy()
        external_data_norm[norm_data_cols] = external_data_norm[norm_data_cols].div(
            norm_vals, axis=1
        )

        return external_data_norm

    def train_som(self, overwrite=False):
        # if overwrite flag set, retrain SOM regardless of state
        if overwrite:
            warnings.warn('Overwrite flag set, retraining SOM')
        # otherwise, do not train SOM if weights already exist and the same markers used to train
        elif self.pixel_train_data.uns["weights"] is not None:
            if set(self.pixel_train_data.uns["weights"].columns.values) == set(self.pixel_train_data.var["channel"].values):
                warnings.warn('Pixel SOM already trained on specified markers')
                return

            # notify the user that different markers specified
            warnings.warn('New markers specified, retraining')

        # X is a bare array; to_df() supplies the channel names used as the weight columns
        super().train_som(self.pixel_train_data.to_df())

    def assign_som_clusters(self, external_data: pd.DataFrame,
                            normalize_data: bool = True,
                            num_parallel_pixels: int = 1000000) -> pd.DataFrame:
        # normalize external_data prior to assignment, if normalize_data set
        external_data_norm = self.normalize_data(external_data) if normalize_data \
            else external_data.copy()
        som_labels = super().generate_som_clusters(
            external_data_norm, num_parallel_obs=num_parallel_pixels
        )

        # assign SOM clusters to external_data
        external_data_norm['pixel_som_cluster'] = som_labels

        # update the total number of SOM clusters
        self.som_clusters_seen.update(list(np.unique(som_labels)))

        return external_data_norm

class CellSOMCluster(PixieSOMCluster):
    def __init__(self, cell_train_data: ad.AnnData, train_data_path: str, num_passes: int = 1,
                 xdim: int = 10, ydim: int = 10, lr_start: float = 0.05, lr_end: float = 0.01,
                 seed=42, normalize=True):
        # NOTE: the base class stores the AnnData object as self.pixel_train_data
        super().__init__(
            cell_train_data, train_data_path, num_passes, xdim, ydim, lr_start, lr_end, seed
        )

        # since cell_data is the only dataset, we can just normalize it immediately
        if normalize:
            self.normalize_data()

    def normalize_data(self):
        # divide each column by its 99.9% quantile, ignoring zeros (quantile skips NaN)
        cell_df = self.pixel_train_data.to_df()
        cell_norm_vals = cell_df.replace(0, np.nan).quantile(q=0.999, axis=0)
        self.pixel_train_data.X = cell_df.div(cell_norm_vals, axis=1).values

    def train_som(self, overwrite=False):
        # if overwrite flag set, retrain SOM regardless of state
        if overwrite:
            warnings.warn('Overwrite flag set, retraining SOM')

        # otherwise, do not train SOM if weights already exist and the same columns used to train
        elif self.pixel_train_data.uns["weights"] is not None:
            if set(self.pixel_train_data.uns["weights"].columns.values) == set(self.pixel_train_data.var_names):
                warnings.warn('Cell SOM already trained on specified columns')
                return

            # notify the user that different columns specified
            warnings.warn('New columns specified, retraining')

        super().train_som(self.pixel_train_data.to_df())

    def assign_som_clusters(self, num_parallel_cells=1000000) -> ad.AnnData:
        # cell data is already normalized, don't repeat; reset to a positional index for batching
        som_labels = super().generate_som_clusters(
            self.pixel_train_data.to_df().reset_index(drop=True), num_parallel_obs=num_parallel_cells
        )

        # assign SOM clusters to cell data
        self.pixel_train_data.obs["cell_som_cluster"] = som_labels

        return self.pixel_train_data

...

class PixieConsensusCluster:
    def __init__(self, cluster_type: str, pixie_data: ad.AnnData, pixie_path: str,
                 max_k: int = 20, cap: float = 3):
        # validate the cluster_type provided
        verify_in_list(
            provided_cluster_type=cluster_type,
            supported_cluster_types=['pixel', 'cell']
        )

        self.pixie_data = pixie_data
        self.pixie_path = pixie_path
        self.cluster_type = cluster_type
        self.som_col = '%s_som_cluster' % cluster_type
        self.meta_col = '%s_meta_cluster' % cluster_type
        # the variables (channels / expression columns) to z-score and cluster on
        self.columns = list(pixie_data.var_names)
        self.max_k = max_k
        self.cap = cap

        # NOTE: H set to 10 to replicate default 'reps' arg in R ConsensusClusterPlus
        # resample_proportion set to 0.8 to replicate default 'pItem' arg in ConsensusClusterPlus
        self.cc = ConsensusCluster(
            cluster=AgglomerativeClustering,
            L=max_k,
            K=max_k,
            H=10,
            resample_proportion=0.8
        )

        self.mapping = None

    def scale_data(self):
        """z-scores and caps the SOM-aggregated data.

        Scaling will be done on a per-column basis for all column names specified.
        Capping will truncate the data in the range `[-cap, cap]`.
        """
        # the aggregated table has one row per SOM cluster, not one row per variable,
        # so it is kept in uns (varm entries must have exactly n_vars rows)
        som_agg_data = self.pixie_data.uns["som_agg_data"].copy()

        # z-score the data
        som_agg_data[self.columns] = som_agg_data[self.columns].apply(zscore)

        # cap the data in the range [-cap, cap]
        som_agg_data[self.columns] = som_agg_data[self.columns].clip(lower=-self.cap, upper=self.cap)

        self.pixie_data.uns["som_agg_data_zscore_cap"] = som_agg_data

    def run_consensus_clustering(self):
        """Fits the meta clustering results using `ConsensusCluster`.
        """
        self.cc.fit(self.pixie_data.uns["som_agg_data_zscore_cap"][self.columns])

        # NOTE: the fitted ConsensusCluster isn't a type AnnData can write to .h5ad,
        # so it stays on self.cc rather than in uns
        self.pixie_data.write_h5ad(self.pixie_path)

    def generate_som_to_meta_map(self):
        """Maps each `'{self.cluster_type}_som_cluster'` to the meta cluster
        generated by `ConsensusCluster`.

        Also assigns mapping to `self.mapping` for use in `assign_consensus_labels`.
        """
        som_agg_data = self.pixie_data.uns["som_agg_data_zscore_cap"]
        som_agg_data[self.meta_col] = self.cc.predict_data(som_agg_data[self.columns])
        self.mapping = som_agg_data[[self.som_col, self.meta_col]].copy().astype(int)

        # make the meta cluster labels 1-indexed
        self.mapping.loc[:, self.meta_col] += 1

        # persist the mapping for fallback purposes
        self.pixie_data.uns["som_to_meta_mapping"] = self.mapping
        self.pixie_data.write_h5ad(self.pixie_path)

    def assign_consensus_labels(self, external_data: pd.DataFrame) -> pd.DataFrame:
        """Takes an external dataset and applies `ConsensusCluster` mapping to it.

        Args:
            external_data (pandas.DataFrame):
                A dataset which contains a `'{self.cluster_type}_som_cluster'` column.

        Returns:
            pandas.DataFrame:
                The `external_data` with a `'{self.cluster_type}_meta_cluster'` column attached.
        """
        external_data[self.meta_col] = external_data[self.som_col].map(
            self.pixie_data.uns["som_to_meta_mapping"].set_index(self.som_col)[self.meta_col]
        )
        return external_data
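The normalization in `CellSOMCluster.normalize_data` and the scaling in `PixieConsensusCluster.scale_data` can also be written directly against a numpy array (AnnData's X is typically one) without a round trip through pandas. A minimal sketch of both, assuming a dense float matrix (sparse X is not handled); the function names are hypothetical. `scipy.stats.zscore` defaults to `ddof=0`, which is also `np.std`'s default, so the z-scores should agree with it, and `np.nanquantile` with its default linear interpolation matches pandas' `quantile` skipping NaN.

```python
import numpy as np


def quantile_normalize(X: np.ndarray, q: float = 0.999) -> np.ndarray:
    """Divide each column by its q-quantile, computed over nonzero entries only.

    As in the pandas version, an all-zero column yields NaN.
    """
    X = np.asarray(X, dtype=np.float64)
    nonzero = np.where(X == 0, np.nan, X)
    norm_vals = np.nanquantile(nonzero, q, axis=0)
    return X / norm_vals


def zscore_and_cap(data: np.ndarray, cap: float = 3.0) -> np.ndarray:
    """Column-wise z-score (ddof=0), then clip to [-cap, cap]."""
    data = np.asarray(data, dtype=np.float64)
    z = (data - data.mean(axis=0)) / data.std(axis=0)
    return np.clip(z, -cap, cap)
```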

We can extrapolate this to the existing workflows to leverage SOM and meta cluster fitting and prediction.
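`generate_som_clusters` delegates the actual assignment to pyFlowSOM's `map_data_to_nodes`, one batch at a time. To illustrate why batching cannot change the result, here is a hypothetical numpy stand-in that labels each row with its nearest SOM node, assuming plain Euclidean distance. It is not pyFlowSOM's implementation, and its labels are 0-based node indices; whether that matches pyFlowSOM's labeling is not assumed.

```python
import numpy as np


def assign_nearest_node(weights: np.ndarray, data: np.ndarray,
                        batch_size: int = 1_000_000) -> np.ndarray:
    """Label each row of data with the index of its nearest SOM node, working in batches.

    weights: (n_nodes, n_features) SOM node weights; data: (n_obs, n_features), with columns
    in the same order as weights. Returns 0-based node indices of shape (n_obs,).
    """
    if batch_size <= 0:
        raise ValueError("batch_size needs to be greater than 0")
    weights = np.asarray(weights, dtype=np.float64)
    data = np.asarray(data, dtype=np.float64)

    # ||x - w||^2 = ||x||^2 - 2 x.w + ||w||^2; ||x||^2 is constant per row, so the argmin
    # only needs -2 x.w + ||w||^2, an (n_batch, n_nodes) array rather than a 3D one
    node_sq_norms = (weights ** 2).sum(axis=1)

    labels = []
    for start in range(0, data.shape[0], batch_size):
        batch = data[start:start + batch_size]  # positional, end-exclusive slice
        scores = node_sq_norms[None, :] - 2.0 * batch @ weights.T
        labels.append(scores.argmin(axis=1))

    # no observations (e.g. an empty FOV): return an empty label array
    if not labels:
        return np.empty(0, dtype=np.int64)
    return np.concatenate(labels)
```

Each row's label depends only on that row and the weights, so any batch size gives the same labels; the batch size only bounds peak memory.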

TODO: include meta cluster remapping stage as well as post-norm visualizations

Required inputs

Same as before; these will be programmatically combined into an AnnData object.

Output files

Instead of several fragmented files, the goal is to unify all of this into a single AnnData object that gets saved to a single .h5ad file.

Timeline

Estimated date when a fully implemented version will be ready for review:

Early next year.

Estimated date when the finalized project will be merged in:

Early next year.

alex-l-kong commented 2 weeks ago

@jranek @camisowers this is a very rough draft but provides an overview of how an AnnData conversion could look. Depending on how the FlowSOM_Python package works for us, that could encapsulate away a lot of the AnnData representation. Let me know your thoughts, thanks!

alex-l-kong commented 1 week ago

@srivarra

alex-l-kong commented 1 week ago

@cliu72