Status: Open — ThomasHeap opened this issue 9 months ago.
Can you give an example?
On Wed, 10 Jan 2024 at 21:07, thomas @.***> wrote:
If the user forgets to name the plate dimensions of their inputs, there should be a clear error message, rather than an obscure failure later from an assert or a matrix multiplication. Perhaps a check that flags inputs with unnamed dimensions whose size matches a plate size?
Similarly, there should be a warning if P and Q define the same latent variable with different dimensionality.
— Reply to this email directly, view it on GitHub https://github.com/alan-ppl/alan/issues/8, or unsubscribe https://github.com/notifications/unsubscribe-auth/ABAUSQQG4P6HDKKJ2QLWELLYN37K3AVCNFSM6AAAAABBVN5U3WVHI2DSMVQWIX3LMV43ASLTON2WKOZSGA3TKMRQGY4DQOA . You are receiving this because you are subscribed to this thread.Message ID: @.***>
import torch as t
from alan import Normal, Bernoulli, Plate, BoundPlate, Problem, Data, checkpoint, OptParam, QEMParam
computation_strategy = checkpoint
P = Plate(
mu = Normal(0, 1, sample_shape = t.Size([2])),
p1 = Plate(
theta = Normal('mu', 1),
obs = Bernoulli(logits = lambda theta, x: theta @ x)
)
)
Q = Plate(
mu = Normal(QEMParam(t.zeros((2,))), QEMParam(t.ones((2,)))),
p1 = Plate(
theta = Normal('mu', 1),
obs = Data()
)
)
platesizes = {'p1': 3}
inputs = {'x': t.randn((3, 2))}
P = BoundPlate(P, platesizes, inputs=inputs)
Q = BoundPlate(Q, platesizes)
P_sample = P.sample()
data = {'obs': t.tensor([1., 1., 0.], names=('p1',))}
prob = Problem(P, Q, data)
sample = prob.sample(K=10)
sample.update_qem_params(0.01)
This gives the following error:
Traceback (most recent call last):
File "/home/thomas/Work/alan/examples/examples/test.py", line 26, in <module>
P = BoundPlate(P, platesizes, inputs=inputs)
File "/home/thomas/Work/alan/src/alan/BoundPlate.py", line 226, in __init__
self.sample()
File "/home/thomas/Work/alan/src/alan/BoundPlate.py", line 375, in sample
torchdim_tree_withK, _ = self._sample(1, False, PermutationSampler, all_platedims)
File "/home/thomas/Work/alan/src/alan/BoundPlate.py", line 355, in _sample
sample = self.plate.sample(
File "/home/thomas/Work/alan/src/alan/Plate.py", line 129, in sample
platesample = prog.sample(
File "/home/thomas/Work/alan/src/alan/Plate.py", line 114, in sample
childsample = sample_gdt(
File "/home/thomas/Work/alan/src/alan/dist.py", line 67, in sample_gdt
sample = dist.sample(scope, reparam, active_platedims, K_dim, timeseries_perm)
File "/home/thomas/Work/alan/src/alan/dist.py", line 300, in sample
return self.tdd(scope).sample(
File "/home/thomas/Work/alan/src/alan/dist.py", line 227, in tdd
return TorchDimDist(self.dist, **self.paramname2val(scope))
File "/home/thomas/Work/alan/src/alan/dist.py", line 217, in paramname2val
val = func(*[scope[arg] for arg in function_arguments(func)])
File "/home/thomas/Work/alan/examples/examples/test.py", line 10, in <lambda>
obs = Bernoulli(logits = lambda theta, x: theta @ x)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (3x2 and 3x2)
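The real problem is that `x` was never given its named `p1` dimension (compare `t.randn((3, 2)).rename('p1', ...)` in the second example below). Going only on this error message, though, I might change: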
obs = Bernoulli(logits = lambda theta, x: theta @ x)
to
obs = Bernoulli(logits = lambda theta, x: theta @ x.t())
which results in the following error:
Traceback (most recent call last):
File "/home/thomas/Work/alan/examples/examples/test.py", line 35, in <module>
sample.update_qem_params(0.01)
File "/home/thomas/Work/alan/src/alan/Sample.py", line 352, in update_qem_params
self.problem.Q._update_qem_params(lr, self, computation_strategy=computation_strategy)
File "/home/thomas/Work/alan/src/alan/BoundPlate.py", line 298, in _update_qem_params
self._update_qem_moving_avg(lr, sample, computation_strategy)
File "/home/thomas/Work/alan/src/alan/BoundPlate.py", line 286, in _update_qem_moving_avg
new_moment_list = sample.moments(rmkey_list, computation_strategy=computation_strategy)
File "/home/thomas/Work/alan/src/alan/moments.py", line 162, in named_moments_mixin
result = self._moments_uniform_input(moms, **kwargs)
File "/home/thomas/Work/alan/src/alan/Sample.py", line 339, in _moments_uniform_input
L = self._elbo(self.detached_sample, extra_log_factors=f_J_torchdim_dict, computation_strategy=computation_strategy)
File "/home/thomas/Work/alan/src/alan/Sample.py", line 92, in _elbo
lp, _, _, _ = logPQ_plate(
File "/home/thomas/Work/alan/src/alan/logpq.py", line 45, in logPQ_plate
lpq = lpq_func(
File "/home/thomas/Work/alan/src/alan/logpq.py", line 103, in _logPQ_plate
lps, all_Ks, K_currs, K_inits = lp_getter(
File "/home/thomas/Work/alan/src/alan/logpq.py", line 296, in lp_getter
lp, _Knon_timeseries, _Ktimeseries, _Kinits = method(
File "/home/thomas/Work/alan/src/alan/logpq.py", line 45, in logPQ_plate
lpq = lpq_func(
File "/home/thomas/Work/alan/src/alan/logpq.py", line 103, in _logPQ_plate
lps, all_Ks, K_currs, K_inits = lp_getter(
File "/home/thomas/Work/alan/src/alan/logpq.py", line 296, in lp_getter
lp, _Knon_timeseries, _Ktimeseries, _Kinits = method(
File "/home/thomas/Work/alan/src/alan/logpq.py", line 193, in logPQ_gdt
lp, _ = prog_P[k].log_prob(data[k], scope, None, None)
File "/home/thomas/Work/alan/src/alan/dist.py", line 297, in log_prob
return self.tdd(scope).log_prob(sample), None
File "/home/thomas/Work/alan/src/alan/TorchDimDist.py", line 154, in log_prob
x_tensor = ultimate_order(x, x_dims)
File "/home/thomas/Work/alan/src/alan/utils.py", line 317, in ultimate_order
assert generic_ndim(x) == sum(dim == slice(None) for dim in dims)
AssertionError
This error doesn't tell me what I've done wrong.
For the case where a latent has different sizes in P and Q:
import torch as t
from alan import Normal, Bernoulli, Plate, BoundPlate, Problem, Data, checkpoint, OptParam, QEMParam
computation_strategy = checkpoint
P = Plate(
mu = Normal(0, 1, sample_shape = t.Size([2])),
p1 = Plate(
theta = Normal('mu', 1, sample_shape = t.Size([2])),
obs = Bernoulli(logits = lambda theta, x: theta @ x)
)
)
Q = Plate(
mu = Normal(QEMParam(t.zeros((2,))), QEMParam(t.ones((2,)))),
p1 = Plate(
theta = Normal('mu', 1),
obs = Data()
)
)
platesizes = {'p1': 3}
inputs = {'x': t.randn((3, 2)).rename('p1', ...)}
P = BoundPlate(P, platesizes, inputs=inputs)
Q = BoundPlate(Q, platesizes)
P_sample = P.sample()
data = {'obs': t.tensor([1., 1., 0.], names=('p1',))}
prob = Problem(P, Q, data)
sample = prob.sample(K=10)
sample.update_qem_params(0.01)
This doesn't throw an error, but should it? Here `theta` has `sample_shape = t.Size([2])` in P but is a scalar in Q.
If the user forgets to name the plate dimensions of their inputs, there should be a clear error message, rather than an obscure failure later from an assert or a matrix multiplication. Perhaps a check that flags inputs with unnamed dimensions whose size matches a plate size?
Similarly, there should be a warning if P and Q define the same latent variable with different dimensionality.