mjo22 / cryojax

Cryo electron microscopy image simulation and analysis built on JAX.
https://mjo22.github.io/cryojax/
GNU Lesser General Public License v2.1
29 stars 9 forks source link

add in bioem marginal distribution into cryojax.inference.distributions #208

Open geoffwoollard opened 5 months ago

geoffwoollard commented 5 months ago

This can be used as a template:

https://github.com/mjo22/cryojax/blob/main/src/cryojax/inference/distributions/_gaussian_distributions.py

geoffwoollard commented 5 months ago
class BioemSensor(torch.nn.Module):
    """BioEM-style marginal likelihood of an observed image given a simulated one.

    Implements Equations 4 and 10 on p. 6 of the SI of 10.1016/j.jsb.2013.10.006
    for the image model

        observed ~ N * simulated + mu + Gaussian noise

    where the per-image intensity scale ``N`` and offset ``mu`` are integrated out
    analytically. The marginalization is selected with ``method``:

    - ``"N-mu"``: flat (uniform) prior on ``N`` and ``mu``, noise std ``sigma``.
    - ``"saddle-approx"``: saddle-point approximation over the noise scale
      lambda; ``sigma`` does not enter this branch.
    - ``"N-mu-gaussian-prior-N"``: Gaussian prior on ``N`` with mean
      ``N_prior_mean`` and standard deviation ``N_prior_std``.

    The likelihood depends on the images only through products of an even number
    of sign-carrying sums, so it is invariant to flipping the sign of the
    simulated and/or observed image.

    The uniform bounds ``N_lo``/``N_hi`` and ``mu_lo``/``mu_hi`` are used only by
    ``sample``; the likelihood branches do not read them.

    Notes
    -----
    Numerical issues seen when reconstructing from empirical data:
        N-mu: NaNs in ``up`` and ``down`` after 631/4750 iterations, batch size 2
        saddle-approx: NaNs in ``term1`` and ``term2`` at iteration 2333/4750,
        batch size 2
    """

    sigma: torch.Tensor
    N_hi: torch.Tensor
    N_lo: torch.Tensor
    mu_hi: torch.Tensor
    mu_lo: torch.Tensor
    mask: Optional[torch.Tensor]

    def __init__(self, image: "ImageConfig",
                 sigma: float,
                 N_hi: float = 1.0,
                 N_lo: float = 0.1,
                 mu_hi: float = +10.0,
                 mu_lo: float = -10.0,
                 mask_radius: Optional[float] = None,
                 method: str = 'saddle-approx',
                 N_prior_mean: float = 1.0,
                 N_prior_std: float = 100.0):
        """
        Args:
            image: image configuration; only ``height`` and ``width`` are read,
                and only when ``mask_radius`` is given.
            sigma: noise standard deviation (``<= 0`` falls back to scale 1).
            N_hi, N_lo: bounds of the uniform draw of ``N`` in ``sample``.
            mu_hi, mu_lo: bounds of the uniform draw of ``mu`` in ``sample``.
            mask_radius: if given, both images are multiplied by a circular mask
                of this radius before computing the likelihood.
            method: one of ``"N-mu"``, ``"saddle-approx"``,
                ``"N-mu-gaussian-prior-N"``.
            N_prior_mean, N_prior_std: mean / std of the Gaussian prior on ``N``
                used by ``"N-mu-gaussian-prior-N"`` (previously hard-coded to
                1 and 100).
        """
        super().__init__()

        self.register_buffer('sigma', torch.tensor(sigma))
        self.register_buffer('N_hi', torch.tensor(N_hi))
        self.register_buffer('N_lo', torch.tensor(N_lo))
        self.register_buffer('mu_hi', torch.tensor(mu_hi))
        self.register_buffer('mu_lo', torch.tensor(mu_lo))

        self.mask_radius = mask_radius
        self.method = method
        self.N_prior_mean = N_prior_mean
        self.N_prior_std = N_prior_std

        if mask_radius is not None:
            mask = cryonerf.nn.affine.make_circular_mask(
                (image.height, image.width), self.mask_radius
            )
        else:
            mask = None
        # Register in both cases so ``self.mask`` is always a (possibly None)
        # buffer and follows ``.to(device)`` / ``state_dict`` consistently.
        self.register_buffer('mask', mask)

    def likelihood(
        self,
        simulated: torch.Tensor,
        observed: torch.Tensor,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Per-pixel negative log marginal likelihood (up to constants).

        Args:
            simulated: simulated image(s); the last two dims are the image axes,
                any leading dims are treated as batch dims.
            observed: observed image(s), same shape as ``simulated``.
            generator: RNG used only when noise must be injected into an
                all-zero simulated image.

        Returns:
            ``(neg_log_prob, info)`` where ``neg_log_prob`` has the leading
            (batch) shape and is divided by the number of pixels, and ``info``
            holds ``'likelihood_scale'`` (the pixel count) and ``'neg_log_prob'``.

        Raises:
            NotImplementedError: if ``self.method`` is not a known method.
            AssertionError: if the result contains NaNs (N-mu / saddle-approx).
        """
        import warnings  # local import: keeps this block self-contained

        # scale = 1 / (2 sigma^2); fall back to 1 when sigma <= 0.
        scale = torch.where(
            self.sigma > 0, 0.5 / self.sigma.square(), torch.ones_like(self.sigma)
        )

        if self.mask is not None:
            observed = observed * self.mask
            simulated = simulated * self.mask

        eps = torch.finfo(torch.float32).eps

        ccc = simulated.pow(2).sum(dim=(-1, -2))
        # A (nearly) all-zero simulated image makes the (N, mu) design singular.
        # Inject noise into exactly the images flagged here (the previous code
        # flagged with ``isclose`` but only perturbed images with ``ccc == 0``).
        degenerate = torch.isclose(ccc, torch.zeros_like(ccc))
        if degenerate.any():
            warnings.warn(
                'WARNING: simulator all zeros, so ccc too close to zero. '
                'Injecting noise to avoid nans.',
                stacklevel=2,
            )
            noise_level = (2 * scale).rsqrt()  # == sigma when sigma > 0
            noise = noise_level * torch.randn(
                simulated.shape,
                generator=generator,
                device=simulated.device,
                dtype=simulated.dtype,
            )
            # [..., None, None] broadcasts over the image axes for any batch rank.
            simulated = torch.where(
                degenerate[..., None, None], simulated + noise, simulated
            )
            ccc = simulated.pow(2).sum(dim=(-1, -2))

        # Sufficient statistics (BioEM notation).
        co = observed.sum(dim=(-1, -2))
        cc = simulated.sum(dim=(-1, -2))
        coo = observed.pow(2).sum(dim=(-1, -2))
        coc = (observed * simulated).sum(dim=(-1, -2))

        # NOTE(review): with a mask, masked-out pixels are still counted here,
        # so they act as observed zeros; confirm whether mask.sum() is intended.
        n_pix = observed.shape[-1] * observed.shape[-2]

        if self.method == 'N-mu':
            # up / down is the minimum residual sum of squares over (N, mu);
            # down is the determinant of the (N, mu) Gram matrix.
            up = (n_pix * (ccc * coo - coc * coc) + 2 * co * coc * cc
                  - ccc * co * co - coo * cc * cc)
            down = n_pix * ccc - cc * cc
            # Protect against 0/0 in both the forward value and the gradient:
            # the denominator itself is made safe, not only the selected output.
            both_zero = torch.logical_and(up == 0, down == 0)
            safe_down = torch.where(both_zero, torch.ones_like(down), down)
            up_over_down = torch.where(both_zero, torch.ones_like(up), up / safe_down)
            # Integrating out (N, mu) yields (2 pi sigma^2)^(1 - n_pix/2) / sqrt(down),
            # hence the 0.5 * (2 - n_pix) coefficient on log(2 * scale) = -log(sigma^2).
            # Other constant factors are neglected.
            neg_log_prob = (
                scale * up_over_down
                + 0.5 * safe_log(down.clamp(min=eps))
                + 0.5 * (2 - n_pix) * safe_log(scale * 2)
            )
            assert not neg_log_prob.isnan().any(), \
                f'TODO: numerically stabilize... up={up}|down={down}'

        elif self.method == 'saddle-approx':
            term1 = (n_pix * (ccc * coo - coc * coc) + 2 * co * coc * cc
                     - ccc * co * co - coo * cc * cc)
            term2 = (n_pix - 2) * (n_pix * ccc - cc * cc)
            neg_log_prob = (
                -(1.5 - n_pix / 2) * safe_log(term1.clamp(min=eps))
                - (n_pix / 2 - 2) * safe_log(term2.clamp(min=eps))
            )
            assert not neg_log_prob.isnan().any(), \
                f'TODO: numerically stabilize... term1={term1}|term2={term2}'

        elif self.method == 'N-mu-gaussian-prior-N':
            a = -n_pix * scale

            a2 = (cc * cc / n_pix - ccc) * scale
            b2 = (coc - cc * co / n_pix) * scale
            c2 = (co * co / n_pix - coo) * scale

            # Gaussian prior on N: coefficients of -(N - mean)^2 / (2 std^2).
            lambda_N = self.N_prior_std
            mu_N = self.N_prior_mean
            a3 = -1 / (2 * lambda_N * lambda_N)
            b3 = mu_N / (lambda_N * lambda_N)
            c3 = -mu_N * mu_N / (2 * lambda_N * lambda_N)

            neg_log_prob = (
                0.5 * safe_log(-a2 - a3)
                + 0.5 * safe_log(-a)
                + (b2 + b3) ** 2 / (4 * (a2 + a3))
                - (c2 + c3)
                + math.log(lambda_N)
            )

        else:
            raise NotImplementedError(
                f"choose a method: got {self.method!r}, expected one of "
                "'N-mu', 'saddle-approx', 'N-mu-gaussian-prior-N'"
            )

        # Out-of-place so the autograd graph is never modified in place.
        neg_log_prob = neg_log_prob / n_pix

        likelihood_scale = simulated.new_tensor(n_pix)

        return neg_log_prob, {'likelihood_scale': likelihood_scale, 'neg_log_prob': neg_log_prob}

    def sample(self, simulated: torch.Tensor, generator: Optional[torch.Generator] = None):
        """Draw ``N * simulated + mu + sigma * noise``.

        One ``N ~ U(N_lo, N_hi)`` and one ``mu ~ U(mu_lo, mu_hi)`` are drawn per
        image (the last two dims are the image axes; leading dims are batch dims).

        Returns:
            ``(sample, {})``.
        """
        # Shape (*batch, 1, 1): one scalar per image, broadcast over (H, W).
        # Works for unbatched (H, W) inputs too, unlike reshape(-1, 1, 1).
        param_shape = tuple(simulated.shape[:-2]) + (1, 1)
        rng = dict(generator=generator, device=simulated.device, dtype=simulated.dtype)
        N = self.N_lo + (self.N_hi - self.N_lo) * torch.rand(param_shape, **rng)
        mu = self.mu_lo + (self.mu_hi - self.mu_lo) * torch.rand(param_shape, **rng)
        noise = torch.randn(simulated.shape, **rng)
        return N * simulated + noise.mul_(self.sigma) + mu, {}

    def forward(
        self,
        shot_info: Dict[str, torch.Tensor],
        simulated: torch.Tensor,
        observed: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Sample when ``observed`` is None, otherwise evaluate the likelihood.

        ``shot_info`` is accepted for interface compatibility and not used here.
        """
        if observed is None:
            return self.sample(simulated, generator=generator)
        return self.likelihood(simulated, observed, generator=generator)
geoffwoollard commented 5 months ago
Screen Shot 2024-04-15 at 10 02 32 AM
mjo22 commented 5 months ago

Check out `cryojax.inference.distributions.AbstractMarginalDistribution`; this is how I had in mind to implement it. The caveat is that it also requires an implementation of the unmarginalized likelihood, which I think would be a good idea anyway, at least for the BioEM likelihood.