flatironinstitute / Cryo-EM-Heterogeneity-Challenge-1

The Inaugural Flatiron Institute Cryo-EM Heterogeneity Community Challenge
https://osf.io/8h6fz/
MIT License

generalize map to map distance #31

Open geoffwoollard opened 6 days ago

geoffwoollard commented 6 days ago

Keep external distance calculations in another repo and connect them to this one through a plugin.

import torch


class MapToMapDistance:
    def __init__(self, config):
        self.config = config

    def get_distance(self, map1, map2):
        raise NotImplementedError()

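    def get_distance_matrix(self, maps1, maps2):
        chunk_size_submission = self.config["analysis"]["chunk_size_submission"]
        chunk_size_gt = self.config["analysis"]["chunk_size_gt"]

        # Inner vmap: compare one map from maps1 against every map in maps2.
        vmap_over_map2 = torch.vmap(self.get_distance, in_dims=(None, 0), chunk_size=chunk_size_gt)
        # Outer vmap: repeat for every map in maps1, giving shape (len(maps1), len(maps2)).
        vmap_over_map1 = torch.vmap(vmap_over_map2, in_dims=(0, None), chunk_size=chunk_size_submission)

        return vmap_over_map1(maps1, maps2)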

class L2Distance(MapToMapDistance):
    def __init__(self, config):
        super().__init__(config)

    def get_distance(self, map1, map2):
        return torch.norm(map1 - map2)

class CorrelationDistance(MapToMapDistance):
    def __init__(self, config):
        super().__init__(config)

    def get_distance(self, map1, map2):
        # Unnormalized inner product: a similarity (larger means more alike), not a metric.
        return torch.sum(map1 * map2)

class BioEMDistance(MapToMapDistance):
    def __init__(self, config):
        super().__init__(config)

    def get_distance(self, map1, map2):
        '''
        Compute the cost between two maps using the BioEM cost function in 3D.

        Notes
        -----
        See Eq. 10 in 10.1016/j.jsb.2013.10.006

        Parameters
        ----------
        map1 : torch.Tensor
            shape (n_pix,n_pix,n_pix)
        map2 : torch.Tensor
            shape (n_pix,n_pix,n_pix)

        Returns
        -------
        cost : torch.Tensor
            scalar tensor, shape ()
        '''
        m1, m2 = map1.reshape(-1), map2.reshape(-1)
        co = m1.sum()
        cc = m2.sum()
        coo = m1.pow(2).sum()
        ccc = m2.pow(2).sum()
        coc = (m1*m2).sum()

        N = len(m1)

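        # torch.exp and torch.log only accept tensors, so build 2*pi*e as a tensor.
        t1 = torch.tensor(2*torch.pi*torch.e, dtype=m1.dtype, device=m1.device)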
        t2 = (N*(ccc*coo-coc*coc) + 2*co*coc*cc - ccc*co*co - coo*cc*cc)
        t3 = ((N-2)*(N*ccc-cc*cc))

        smallest_float = torch.finfo(m1.dtype).tiny
        log_prob = 0.5*torch.pi + torch.log(t1)*(1-N/2) + t2.clamp(smallest_float).log()*(3/2-N/2) + t3.clamp(smallest_float).log()*(N/2-2)
        cost = -log_prob

        return cost
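
One way the plugin connection could work is a name-to-class registry that an external repo fills in when it is imported. The sketch below is only an illustration under that assumption: `register_distance`, `make_distance`, and `AbsDiffDistance` are hypothetical names that do not exist in this repo, and the toy class operates on plain lists of floats rather than tensors so that it runs without any dependencies.

```python
# Hypothetical registry sketch; none of these names exist in the challenge repo.
from typing import Callable, Dict

_DISTANCE_REGISTRY: Dict[str, type] = {}


def register_distance(name: str) -> Callable[[type], type]:
    """Class decorator that files a distance class under a lookup name."""
    def decorator(cls: type) -> type:
        if name in _DISTANCE_REGISTRY:
            raise ValueError(f"distance {name!r} is already registered")
        _DISTANCE_REGISTRY[name] = cls
        return cls
    return decorator


def make_distance(name: str, config: dict):
    """Instantiate a registered distance class from its name."""
    if name not in _DISTANCE_REGISTRY:
        known = sorted(_DISTANCE_REGISTRY)
        raise KeyError(f"unknown distance {name!r}; known: {known}")
    return _DISTANCE_REGISTRY[name](config)


# An external package would only need to import register_distance.
@register_distance("abs_diff")
class AbsDiffDistance:
    """Toy stand-in for a MapToMapDistance subclass (plain floats, not tensors)."""

    def __init__(self, config: dict):
        self.config = config

    def get_distance(self, map1, map2):
        # Sum of absolute element-wise differences.
        return sum(abs(a - b) for a, b in zip(map1, map2))
```

With something like this, the analysis code would look a distance up by the name given in the config (for example `make_distance("abs_diff", config)`) and never import the external implementation directly. Python's `importlib.metadata` entry points are another possible mechanism if the external repo is installed as a package.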

geoffwoollard commented 6 days ago

Suggestions:
- A wrapper function receives the inputs and the distance function. The output has to be a distance matrix and a label, and the wrapper should validate this.
- Give contributor guidelines for how to incorporate a new distance (where to modify the code).
- Offload documentation to the external repo (do not maintain the code of external distance functions in this repo).
- Precomputed aspects can be handled under the hood of the external distance function.
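
The wrapper-and-validate suggestion could look roughly like the sketch below. It assumes the external function takes two collections of maps and returns a `(distance_matrix, label)` tuple. `run_external_distance` and `toy_l1` are hypothetical names for illustration. The shape check uses only `len()` and row indexing, so it should work for nested lists as well as 2-D tensors.

```python
# Hypothetical wrapper sketch; function and argument names are illustrative only.
from typing import Any, Callable, Sequence, Tuple


def run_external_distance(
    distance_fn: Callable[[Sequence[Any], Sequence[Any]], Tuple[Any, str]],
    maps1: Sequence[Any],
    maps2: Sequence[Any],
) -> Tuple[Any, str]:
    """Call an external distance function and validate what it returns."""
    result = distance_fn(maps1, maps2)
    if not (isinstance(result, tuple) and len(result) == 2):
        raise TypeError("distance function must return (distance_matrix, label)")
    matrix, label = result

    if not isinstance(label, str) or not label:
        raise TypeError("label must be a non-empty string")

    # Expect one row per map in maps1 and one column per map in maps2.
    n_rows = len(matrix)
    if n_rows != len(maps1):
        raise ValueError(f"expected {len(maps1)} rows, got {n_rows}")
    for i in range(n_rows):
        n_cols = len(matrix[i])
        if n_cols != len(maps2):
            raise ValueError(f"row {i}: expected {len(maps2)} columns, got {n_cols}")
    return matrix, label


def toy_l1(maps1, maps2):
    """Toy external distance on flat lists of floats (assumed plugin signature)."""
    matrix = [
        [sum(abs(a - b) for a, b in zip(m1, m2)) for m2 in maps2]
        for m1 in maps1
    ]
    return matrix, "toy_l1"
```

Calling `run_external_distance(toy_l1, maps1, maps2)` returns the matrix and label unchanged when they pass the checks, and raises otherwise, so a misbehaving external distance fails early instead of corrupting downstream analysis.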