**geoffwoollard** opened 5 months ago
```python
import math
from typing import Dict, Optional, Tuple

import torch

import cryonerf.nn.affine  # project-local; provides make_circular_mask
# `ImageConfig` and `safe_log` are likewise project-local helpers assumed in scope.


class BioemSensor(torch.nn.Module):
    """BioEM marginalized likelihood.

    Implements Equations 4 and 10 on p. 6 of the SI of
    doi:10.1016/j.jsb.2013.10.006, under the generative model

        observed = N * simulated + mu + Gaussian noise with std ``sigma``,

    with a flat (uniform) prior on N and mu (method="N-mu"), or with a
    saddle-point approximation over the noise precision lambda
    (method="saddle-approx"). The loss is invariant to the sign of the
    simulated and observed images: either can change sign arbitrarily
    without affecting it.

    Notes
    -----
    Numerical issues arise when reconstructing with empirical data:
    - "N-mu": NaNs in ``up`` and ``down`` after iteration 631/4750, batch size 2.
    - "saddle-approx": NaNs in ``term1`` and ``term2`` at iteration 2333/4750, batch size 2.
    """

    sigma: torch.Tensor

    def __init__(
        self,
        image: ImageConfig,
        sigma: float,
        N_hi: float = 1.0,
        N_lo: float = 0.1,
        mu_hi: float = +10.0,
        mu_lo: float = -10.0,
        mask_radius: Optional[float] = None,
        method: str = 'saddle-approx',
    ):
        super().__init__()
        self.register_buffer('sigma', torch.tensor(sigma))
        self.register_buffer('N_hi', torch.tensor(N_hi))
        self.register_buffer('N_lo', torch.tensor(N_lo))
        self.register_buffer('mu_hi', torch.tensor(mu_hi))
        self.register_buffer('mu_lo', torch.tensor(mu_lo))
        self.mask_radius = mask_radius
        self.method = method
        if mask_radius is not None:
            self.register_buffer(
                'mask',
                cryonerf.nn.affine.make_circular_mask(
                    (image.height, image.width), self.mask_radius
                ),
            )
        else:
            self.mask = None

    def likelihood(
        self,
        simulated: torch.Tensor,
        observed: torch.Tensor,
        generator: Optional[torch.Generator] = None,
    ):
        # Precision-like factor 1 / (2 sigma^2); falls back to 1 when sigma == 0.
        scale = torch.where(self.sigma > 0, 0.5 / self.sigma.square(), torch.ones_like(self.sigma))
        if self.mask is not None:
            observed = observed * self.mask
            simulated = simulated * self.mask
        eps = torch.finfo(torch.float32).eps
        ccc = simulated.pow(2).sum(dim=(-1, -2))
        is_zero = torch.isclose(ccc, torch.zeros_like(ccc))
        if is_zero.any():
            print('WARNING: simulated image all zeros, so ccc too close to zero. Injecting noise to avoid NaNs.')
            noise_level = (2 * scale).sqrt().pow(-1)  # equals sigma when sigma > 0
            noise = noise_level * torch.randn(simulated.shape, generator=generator, device=simulated.device, dtype=simulated.dtype)
            simulated = torch.where(is_zero.reshape(-1, 1, 1), simulated + noise, simulated)
            ccc = simulated.pow(2).sum(dim=(-1, -2))
        # Sufficient statistics: sums over pixels, in the SI's notation.
        co = observed.sum(dim=(-1, -2))
        cc = simulated.sum(dim=(-1, -2))
        coo = observed.pow(2).sum(dim=(-1, -2))
        coc = (observed * simulated).sum(dim=(-1, -2))
        n_pix = observed.shape[-1] * observed.shape[-2]
        if self.method == 'N-mu':
            # TODO: include missing piece
            up = n_pix * (ccc * coo - coc * coc) + 2 * co * coc * cc - ccc * co * co - coo * cc * cc
            down = n_pix * ccc - cc * cc
            # Protect against 0/0: guard the denominator so the masked-out branch
            # cannot poison the forward or backward pass with NaNs.
            zero_over_zero = (up == 0) & (down == 0)
            down_safe = torch.where(zero_over_zero, torch.ones_like(down), down)
            up_over_down = torch.where(zero_over_zero, torch.ones_like(up), up / down_safe)
            neg_log_prob = scale * up_over_down + 0.5 * safe_log(down.clamp(min=eps)) + (2 - n_pix) * safe_log(scale * 2)  # constant factors neglected
            assert not neg_log_prob.isnan().any(), 'TODO: numerically stabilize... up={}|down={}'.format(up, down)
        elif self.method == 'saddle-approx':
            term1 = n_pix * (ccc * coo - coc * coc) + 2 * co * coc * cc - ccc * co * co - coo * cc * cc
            term2 = (n_pix - 2) * (n_pix * ccc - cc * cc)
            neg_log_prob = -(1.5 - n_pix / 2) * safe_log(term1.clamp(min=eps)) - (n_pix / 2 - 2) * safe_log(term2.clamp(min=eps))
            assert not neg_log_prob.isnan().any(), 'TODO: numerically stabilize... term1={}|term2={}'.format(term1, term2)
        elif self.method == 'N-mu-gaussian-prior-N':
            # Gaussian prior N ~ Normal(mu_N, lambda_N^2) instead of a flat prior;
            # hyperparameters are hard-coded for now.
            a = -n_pix * scale
            a2 = (cc * cc / n_pix - ccc) * scale
            b2 = (coc - cc * co / n_pix) * scale
            c2 = (co * co / n_pix - coo) * scale
            lambda_N = 100
            mu_N = 1
            a3 = -1 / (2 * lambda_N * lambda_N)
            b3 = mu_N / (lambda_N * lambda_N)
            c3 = -mu_N * mu_N / (2 * lambda_N * lambda_N)
            neg_log_prob = 0.5 * safe_log(-a2 - a3) + 0.5 * safe_log(-a) + (b2 + b3) ** 2 / (4 * (a2 + a3)) - (c2 + c3) + math.log(lambda_N)
        else:
            raise NotImplementedError('choose a method')
        # Ad hoc prior pulling the simulated-image norm toward 1 (disabled by default).
        do_prior = False
        if do_prior:
            beta = 0
            neg_log_prob_prior = (ccc.sqrt() - 1).pow(2)  # alternatively: (simulated.std(dim=(-1, -2)) - 1).pow(2)
            neg_log_prob += beta * n_pix * neg_log_prob_prior
        neg_log_prob /= n_pix
        likelihood_scale = simulated.new_tensor(n_pix)
        return neg_log_prob, {'likelihood_scale': likelihood_scale, 'neg_log_prob': neg_log_prob}

    def sample(self, simulated: torch.Tensor, generator: Optional[torch.Generator] = None):
        # Draw per-image N ~ Uniform(N_lo, N_hi) and mu ~ Uniform(mu_lo, mu_hi),
        # then add Gaussian pixel noise with std sigma.
        N = self.N_lo + (self.N_hi - self.N_lo) * torch.rand(simulated.shape[0], generator=generator, device=simulated.device, dtype=simulated.dtype).reshape(-1, 1, 1)
        mu = self.mu_lo + (self.mu_hi - self.mu_lo) * torch.rand(simulated.shape[0], generator=generator, device=simulated.device, dtype=simulated.dtype).reshape(-1, 1, 1)
        noise = torch.randn(
            simulated.shape, generator=generator, device=simulated.device, dtype=simulated.dtype
        )
        return N * simulated + noise.mul_(self.sigma) + mu, {}

    def forward(
        self,
        shot_info: Dict[str, torch.Tensor],
        simulated: torch.Tensor,
        observed: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        if observed is None:
            return self.sample(simulated, generator=generator)
        else:
            return self.likelihood(simulated, observed, generator=generator)
```
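For reference, reading the `saddle-approx` branch off the code in the SI's pixel-sum notation, with $C_o=\sum o$, $C_c=\sum c$, $C_{oo}=\sum o^2$, $C_{cc}=\sum c^2$, $C_{oc}=\sum oc$, and $n$ the number of pixels:

```latex
t_1 = n\,(C_{cc}C_{oo} - C_{oc}^2) + 2\,C_o C_{oc} C_c - C_{cc}C_o^2 - C_{oo}C_c^2
t_2 = (n - 2)\,(n\,C_{cc} - C_c^2)
-\log p = -\left(\tfrac{3}{2} - \tfrac{n}{2}\right)\log t_1
          - \left(\tfrac{n}{2} - 2\right)\log t_2 + \mathrm{const}
```

Flipping the sign of either image flips $C_o$, $C_c$, and $C_{oc}$ only in even powers or pairs within each product, which is why the loss is sign-invariant as the docstring claims.

As a smoke test, here is a minimal sketch. It assumes stand-ins for the project-local names (a tiny `ImageConfig` dataclass and `safe_log = torch.log`, both hypothetical, and both needing to exist before the class body above executes), and uses `mask_radius=None` so the `cryonerf` mask helper is never touched:

```python
from dataclasses import dataclass

import torch


@dataclass
class ImageConfig:  # hypothetical stand-in for the project's ImageConfig
    height: int
    width: int


safe_log = torch.log  # stand-in for the project's (presumably safer) log helper

sensor = BioemSensor(ImageConfig(64, 64), sigma=1.0, method='saddle-approx')
g = torch.Generator().manual_seed(0)
simulated = torch.randn(2, 64, 64, generator=g)

# observed is None -> sample branch; otherwise -> likelihood branch.
observed, _ = sensor(shot_info={}, simulated=simulated, generator=g)
nll, aux = sensor(shot_info={}, simulated=simulated, observed=observed)

# Sign invariance claimed in the docstring: flipping the simulated image
# leaves the negative log-probability unchanged.
nll_flipped, _ = sensor(shot_info={}, simulated=-simulated, observed=observed)
assert torch.allclose(nll, nll_flipped)
```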
Check out the `cryojax.inference.distributions.AbstractMarginalDistribution` class. This was my idea for implementing this. The caveat is that it will require an implementation of the unmarginalized likelihood as well (which I think would be a good idea, at least for the BioEM likelihood).

You can use this as a template:
https://github.com/mjo22/cryojax/blob/main/src/cryojax/inference/distributions/_gaussian_distributions.py
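To make the suggested split concrete, here is a framework-agnostic sketch (the function name and signature are illustrative, not the cryojax API): the unmarginalized likelihood is just the Gaussian model from the docstring above, `observed = N * simulated + mu + noise`, with `N` and `mu` passed explicitly, and a marginal counterpart then integrates them out.

```python
import torch


def unmarginalized_neg_log_likelihood(
    simulated: torch.Tensor,
    observed: torch.Tensor,
    N: torch.Tensor,
    mu: torch.Tensor,
    sigma: float,
) -> torch.Tensor:
    """Per-image Gaussian NLL with explicit scale N and offset mu, up to constants."""
    residual = observed - (N * simulated + mu)
    return 0.5 * residual.pow(2).sum(dim=(-1, -2)) / sigma**2


# A marginal counterpart integrates N and mu out: in closed form under a flat
# prior (the 'N-mu' branch of BioemSensor.likelihood above), or numerically,
# e.g. by quadrature over a grid of (N, mu) values.
```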