alan-ppl / alan

An efficient, massively parallel probabilistic programming language

Error message when user forgets to name input #8

Open ThomasHeap opened 9 months ago

ThomasHeap commented 9 months ago

If the user forgets to properly name their inputs, there should be some kind of error message, instead of waiting for some assert or matrix multiplication to fail. Perhaps a warning that checks for inputs with unnamed dimensions the size of a plate and flags them up?
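A minimal sketch of the proposed check, written against plain PyTorch named tensors rather than alan's internals. The function name `warn_unnamed_plate_dims` is hypothetical. It assumes the inputs arrive as a dict of tensors and the plate sizes as a dict like the `platesizes` used in the examples later in this thread. It raises a warning rather than an error because a feature dim can legitimately share its size with a plate.

```python
import warnings

import torch as t


def warn_unnamed_plate_dims(inputs, platesizes):
    """Hypothetical check: warn when an input has an unnamed dim whose size
    matches a plate size, since that dim was probably meant to be named."""
    flagged = []
    for name, tensor in inputs.items():
        # tensor.names has one entry per dim: a str if named, None if not.
        for dim_idx, (dim_name, size) in enumerate(zip(tensor.names, tensor.shape)):
            if dim_name is not None:
                continue
            matching = [plate for plate, plate_size in platesizes.items() if plate_size == size]
            if matching:
                flagged.append((name, dim_idx, matching))
                warnings.warn(
                    f"Input '{name}' has an unnamed dim {dim_idx} of size {size}, "
                    f"which matches plate(s) {matching}; did you forget to name it, "
                    f"e.g. with .rename({matching[0]!r}, ...)?"
                )
    return flagged


platesizes = {'p1': 3}
# Unnamed (3, 2) input: dim 0 has the size of plate 'p1', so it is flagged.
print(warn_unnamed_plate_dims({'x': t.randn((3, 2))}, platesizes))
# Same input with dim 0 named 'p1': nothing is flagged.
print(warn_unnamed_plate_dims({'x': t.randn((3, 2)).rename('p1', ...)}, platesizes))
```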

Similarly, there should be a warning if a latent appears in both P and Q with different dimensionality.
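A hypothetical sketch of that second check. It assumes you can get a dict of samples from P and one from Q, keyed by latent name. How to extract these from alan is not shown here. Plate dims are marked with PyTorch named-tensor names, but alan's own samples may represent them differently. Only the unnamed (non-plate) dims are compared.

```python
import warnings

import torch as t


def generic_shape(tensor):
    # Sizes of the dims that carry no name, i.e. the non-plate dims.
    return tuple(size for name, size in zip(tensor.names, tensor.shape) if name is None)


def warn_latent_shape_mismatch(P_samples, Q_samples):
    """Hypothetical check: warn when a latent present in both P and Q has a
    different non-plate shape in each."""
    mismatched = {}
    for name in P_samples.keys() & Q_samples.keys():
        p_shape = generic_shape(P_samples[name])
        q_shape = generic_shape(Q_samples[name])
        if p_shape != q_shape:
            mismatched[name] = (p_shape, q_shape)
            warnings.warn(f"Latent '{name}' has shape {p_shape} in P but {q_shape} in Q.")
    return mismatched


# Illustrative samples: theta lives in plate 'p1' (size 3) in both P and Q.
P_samples = {'mu': t.randn(2), 'theta': t.randn(3, 2, 2).rename('p1', ...)}
Q_samples = {'mu': t.randn(2), 'theta': t.randn(3, 2).rename('p1', ...)}
print(warn_latent_shape_mismatch(P_samples, Q_samples))
```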

LaurenceA commented 9 months ago

Can you give an example?

(Reply sent by email to ThomasHeap's post of Wed, 10 Jan 2024 at 21:07.)

ThomasHeap commented 9 months ago
import torch as t
from alan import Normal, Bernoulli, Plate, BoundPlate, Problem, Data, checkpoint, OptParam, QEMParam

computation_strategy = checkpoint

P = Plate(
    mu = Normal(0, 1, sample_shape = t.Size([2])), 
    p1 = Plate(
        theta = Normal('mu', 1),
        obs = Bernoulli(logits = lambda theta, x: theta @ x)
    )
)

Q = Plate(
    mu = Normal(QEMParam(t.zeros((2,))), QEMParam(t.ones((2,)))),
    p1 = Plate(
        theta = Normal('mu', 1),
        obs = Data()
    )
)

platesizes = {'p1': 3}

inputs = {'x': t.randn((3, 2))}  # dim 0 has the size of plate 'p1' but is left unnamed
P = BoundPlate(P, platesizes, inputs=inputs)
Q = BoundPlate(Q, platesizes)

P_sample = P.sample()
data = {'obs': t.tensor([1., 1., 0.], names=('p1',))}

prob = Problem(P, Q, data)

sample = prob.sample(K=10)
sample.update_qem_params(0.01)

This gives the following error:

Traceback (most recent call last):
  File "/home/thomas/Work/alan/examples/examples/test.py", line 26, in <module>
    P = BoundPlate(P, platesizes, inputs=inputs)
  File "/home/thomas/Work/alan/src/alan/BoundPlate.py", line 226, in __init__
    self.sample()
  File "/home/thomas/Work/alan/src/alan/BoundPlate.py", line 375, in sample
    torchdim_tree_withK, _ = self._sample(1, False, PermutationSampler, all_platedims)
  File "/home/thomas/Work/alan/src/alan/BoundPlate.py", line 355, in _sample
    sample = self.plate.sample(
  File "/home/thomas/Work/alan/src/alan/Plate.py", line 129, in sample
    platesample = prog.sample(
  File "/home/thomas/Work/alan/src/alan/Plate.py", line 114, in sample
    childsample = sample_gdt(
  File "/home/thomas/Work/alan/src/alan/dist.py", line 67, in sample_gdt
    sample = dist.sample(scope, reparam, active_platedims, K_dim, timeseries_perm)
  File "/home/thomas/Work/alan/src/alan/dist.py", line 300, in sample
    return self.tdd(scope).sample(
  File "/home/thomas/Work/alan/src/alan/dist.py", line 227, in tdd
    return TorchDimDist(self.dist, **self.paramname2val(scope))
  File "/home/thomas/Work/alan/src/alan/dist.py", line 217, in paramname2val
    val = func(*[scope[arg] for arg in function_arguments(func)])
  File "/home/thomas/Work/alan/examples/examples/test.py", line 10, in <lambda>
    obs = Bernoulli(logits = lambda theta, x: theta @ x)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (3x2 and 3x2)

So just going from this error I might change:

        obs = Bernoulli(logits = lambda theta, x: theta @ x)

to

        obs = Bernoulli(logits = lambda theta, x: theta @ x.t())

which results in the following error:

Traceback (most recent call last):
  File "/home/thomas/Work/alan/examples/examples/test.py", line 35, in <module>
    sample.update_qem_params(0.01)
  File "/home/thomas/Work/alan/src/alan/Sample.py", line 352, in update_qem_params
    self.problem.Q._update_qem_params(lr, self, computation_strategy=computation_strategy)
  File "/home/thomas/Work/alan/src/alan/BoundPlate.py", line 298, in _update_qem_params
    self._update_qem_moving_avg(lr, sample, computation_strategy)
  File "/home/thomas/Work/alan/src/alan/BoundPlate.py", line 286, in _update_qem_moving_avg
    new_moment_list = sample.moments(rmkey_list, computation_strategy=computation_strategy)
  File "/home/thomas/Work/alan/src/alan/moments.py", line 162, in named_moments_mixin
    result = self._moments_uniform_input(moms, **kwargs)
  File "/home/thomas/Work/alan/src/alan/Sample.py", line 339, in _moments_uniform_input
    L = self._elbo(self.detached_sample, extra_log_factors=f_J_torchdim_dict, computation_strategy=computation_strategy)
  File "/home/thomas/Work/alan/src/alan/Sample.py", line 92, in _elbo
    lp, _, _, _ = logPQ_plate(
  File "/home/thomas/Work/alan/src/alan/logpq.py", line 45, in logPQ_plate
    lpq = lpq_func(
  File "/home/thomas/Work/alan/src/alan/logpq.py", line 103, in _logPQ_plate
    lps, all_Ks, K_currs, K_inits = lp_getter(
  File "/home/thomas/Work/alan/src/alan/logpq.py", line 296, in lp_getter
    lp, _Knon_timeseries, _Ktimeseries, _Kinits = method(
  File "/home/thomas/Work/alan/src/alan/logpq.py", line 45, in logPQ_plate
    lpq = lpq_func(
  File "/home/thomas/Work/alan/src/alan/logpq.py", line 103, in _logPQ_plate
    lps, all_Ks, K_currs, K_inits = lp_getter(
  File "/home/thomas/Work/alan/src/alan/logpq.py", line 296, in lp_getter
    lp, _Knon_timeseries, _Ktimeseries, _Kinits = method(
  File "/home/thomas/Work/alan/src/alan/logpq.py", line 193, in logPQ_gdt
    lp, _ = prog_P[k].log_prob(data[k], scope, None, None)
  File "/home/thomas/Work/alan/src/alan/dist.py", line 297, in log_prob
    return self.tdd(scope).log_prob(sample), None
  File "/home/thomas/Work/alan/src/alan/TorchDimDist.py", line 154, in log_prob
    x_tensor = ultimate_order(x, x_dims)
  File "/home/thomas/Work/alan/src/alan/utils.py", line 317, in ultimate_order
    assert generic_ndim(x) == sum(dim == slice(None) for dim in dims)
AssertionError

That doesn't tell me what I've done wrong.
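For comparison, the next example in this thread avoids the problem by naming the plate dim of `x` with PyTorch's named-tensor `rename`. Here is a minimal illustration of what that call does, in plain PyTorch and independent of alan:

```python
import torch as t

# Unnamed input: nothing records that dim 0 is the 'p1' plate.
x_unnamed = t.randn((3, 2))
print(x_unnamed.names)  # (None, None)

# Named input: `...` keeps the remaining dims as they were (here unnamed),
# so only dim 0 is tagged with the plate name.
x_named = t.randn((3, 2)).rename('p1', ...)
print(x_named.names)  # ('p1', None)
```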

ThomasHeap commented 9 months ago

For the case where P and Q have latents of different sizes:

import torch as t
from alan import Normal, Bernoulli, Plate, BoundPlate, Problem, Data, checkpoint, OptParam, QEMParam

computation_strategy = checkpoint

P = Plate(
    mu = Normal(0, 1, sample_shape = t.Size([2])), 
    p1 = Plate(
        theta = Normal('mu', 1, sample_shape = t.Size([2])),  # extra sample_shape, unlike theta in Q
        obs = Bernoulli(logits = lambda theta, x: theta @ x)
    )
)

Q = Plate(
    mu = Normal(QEMParam(t.zeros((2,))), QEMParam(t.ones((2,)))),
    p1 = Plate(
        theta = Normal('mu', 1),
        obs = Data()
    )
)

platesizes = {'p1': 3}

inputs = {'x': t.randn((3, 2)).rename('p1', ...)}
P = BoundPlate(P, platesizes, inputs=inputs)
Q = BoundPlate(Q, platesizes)

P_sample = P.sample()
data = {'obs': t.tensor([1., 1., 0.], names=('p1',))}

prob = Problem(P, Q, data)

sample = prob.sample(K=10)
sample.update_qem_params(0.01)

This doesn't throw an error, but should it?
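To see concretely how the two `theta` definitions differ, here are the same two distributions written with `torch.distributions.Normal`. This assumes alan's `sample_shape` behaves like PyTorch's, where it prepends dims to each draw. Plate and K dims are ignored, and `mu` is a fixed stand-in for one draw of `mu`.

```python
import torch as t
from torch.distributions import Normal

mu = t.zeros(2)  # stand-in for a single draw of mu, shape (2,)

# P: theta = Normal('mu', 1, sample_shape = t.Size([2]))
theta_P = Normal(mu, 1.0).sample(t.Size([2]))
# Q: theta = Normal('mu', 1)
theta_Q = Normal(mu, 1.0).sample()

print(theta_P.shape)  # torch.Size([2, 2])
print(theta_Q.shape)  # torch.Size([2])
```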