LysSanzMoreta opened 1 year ago
@fritzo I am particularly puzzled about this: why, when i) using the context manager poutine.mask, does the trace show the used mask tensor under "mask", while ii) when using .mask() or obs_mask the "mask" entry of the trace shows None?
I am getting different results when using something like this:
mask_t = torch.tensor([True,True])
logits = torch.tensor([3.,4.])
Case A: The trace shows the tensor mask_t under "mask" and 'fn': Independent(Categorical(logits: torch.Size([2])), 1)
with pyro.poutine.mask(mask=mask_t):
    pyro.sample("c", dist.Categorical(logits=logits).to_event(1))
Case B: The trace shows None under "mask" and then 'fn': Independent(MaskedDistribution(), 1)
pyro.sample("c",dist.Categorical(logits=logits).mask(mask_t).to_event(1))
Is this expected behaviour? Shouldn't there be 2 different mask types?
Thanks :)
IIRC using Distribution.mask() stores the mask internally in the distribution, rather than in the trace; you should be able to see this with trace.nodes[name]["fn"] being a MaskedDistribution. By contrast, poutine.mask() preserves the original distribution in the "fn" slot and stores the mask in the "mask" slot. The reason for the difference is that the trace-based version is usually nicer and clearer, but the distribution mask is needed when you want to mask out part of the event shape, since the trace mask must be broadcastable with batch_shape and can have no event_shape.
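For example, here is a minimal sketch of how to see that difference in a trace (2-D logits are used so that .to_event(1) is valid; the site name "c" and the shapes are only illustrative):

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

mask_t = torch.tensor([True, True])
logits = torch.tensor([[3., 4.], [1., 2.]])  # batch_shape (2,), so .to_event(1) is valid

def model_a():
    # poutine.mask: the mask ends up in the trace's "mask" slot
    with poutine.mask(mask=mask_t):
        pyro.sample("c", dist.Categorical(logits=logits).to_event(1))

def model_b():
    # Distribution.mask: the mask is folded into the distribution itself
    pyro.sample("c", dist.Categorical(logits=logits).mask(mask_t).to_event(1))

tr_a = poutine.trace(model_a).get_trace()
tr_b = poutine.trace(model_b).get_trace()
print(tr_a.nodes["c"]["mask"])  # tensor([True, True])
print(tr_a.nodes["c"]["fn"])    # Independent(Categorical(...), 1)
print(tr_b.nodes["c"]["mask"])  # None
print(tr_b.nodes["c"]["fn"])    # Independent(MaskedDistribution(), 1)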
@fritzo Thanks, that is very nice to know. I was aware that Distribution.mask() is needed when event_shape is involved. However, since I could not "find" where the mask values went, I did not know whether it was actually working (using the given mask as intended).
This also leads me to my next concern (which is what led me to try to find the mask above): I get different results when using
mask_t = torch.tensor([True,True])
logits = torch.tensor([3.,4.])
targets = torch.tensor([0.,1.])
The mask is all True values, therefore indicating that all values should be used for the marginal --> good results:
with pyro.poutine.mask(mask=mask_t):  # the mask is all True
    pyro.sample("c", dist.Categorical(logits=logits).to_event(1), obs=targets)
No poutine mask (therefore all values should be used in the computation?) --> Bad results
pyro.sample("c",dist.Categorical(logits=logits).to_event(1),obs=targets)
I am guessing this has to do with the event shape, but I do not understand how, since in my head they should be equivalent... unless without using poutine.mask the default is False?
Just curious: why does dist.Categorical(logits=logits).to_event(1) not raise an error? dist.Categorical(logits=logits) does not have a batch shape.
Regarding mask, the rule of thumb is that mask only applies to batch dimensions. Assume you have some univariate distribution d and a mask with shape (3,): d.expand([3]).mask(mask).to_event(1) is different from d.expand([3]).to_event(1).mask(mask). The former has no batch dimension and event shape (3,). The latter has batch dimension (3,) due to the last mask operator (whose mask has shape (3,)).
d.expand([3]) --> batch_shape: (3,), event_shape: ()
d.expand([3]).mask(mask) --> batch_shape: (3,), event_shape: ()
d.expand([3]).mask(mask).to_event(1) --> batch_shape: (), event_shape: (3,)
d.expand([3]).to_event(1) --> batch_shape: (), event_shape: (3,)
d.expand([3]).to_event(1).mask(mask) --> batch_shape: (3,), event_shape: (3,)
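Those shapes can be double-checked with a small sketch (d = dist.Normal(0., 1.) is just an example of a univariate distribution):

import torch
import pyro.distributions as dist

mask = torch.tensor([True, False, True])
d = dist.Normal(0., 1.)

for name, dd in [
    ("expand([3])", d.expand([3])),
    ("expand([3]).mask(mask)", d.expand([3]).mask(mask)),
    ("expand([3]).mask(mask).to_event(1)", d.expand([3]).mask(mask).to_event(1)),
    ("expand([3]).to_event(1)", d.expand([3]).to_event(1)),
    ("expand([3]).to_event(1).mask(mask)", d.expand([3]).to_event(1).mask(mask)),
]:
    print(name, "batch_shape:", tuple(dd.batch_shape), "event_shape:", tuple(dd.event_shape))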
@fehiepsi "Just curious, why dist.Categorical(logits=logits).to_event(1)
does not raise an error? dist.Categorical(logits=logits)
does not have batch shape."
Well that is actually a relief to hear, because the .to_event(1) is doing something (in combination with the poutine.mask), but not sure what. And I did not expect that to happen (I am not familiar enough with which distributions have batch or event shape though). I have Pyro 1.8.2
.
Yes, I figured that the order, d.expand([3]).mask(mask).to_event(1) vs d.expand([3]).to_event(1).mask(mask), is important. But this is definitely a confusing factor. And I now understand that poutine.mask applies over the batch dimensions only, but it is still hard to know when to use which masking method.
poutine.mask and Distribution.mask have the same role: masking the log probabilities of a distribution. The log_prob of a distribution has shape batch_shape; when masked, its value is log_prob * mask (note that the broadcasting rule applies for this multiplication).
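A quick sketch of that semantics (the Normal with batch_shape (3,) here is just an illustrative choice):

import torch
import pyro.distributions as dist

mask = torch.tensor([True, False, True])
d = dist.Normal(torch.zeros(3), torch.ones(3))  # batch_shape (3,), event_shape ()
x = torch.randn(3)

lp_masked = d.mask(mask).log_prob(x)  # masked-out entries become 0.
lp_manual = d.log_prob(x) * mask      # the same thing written out by hand
print(torch.allclose(lp_masked, lp_manual))  # True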
obs_mask is used for partially observed data (we have a separate feature request for its tutorial, #1676).
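A minimal sketch of obs_mask (the Normal model and site names are only illustrative; if I recall correctly, pyro.sample with obs_mask internally creates auxiliary "<name>_observed" and "<name>_unobserved" sites plus a deterministic "<name>" site that interleaves them):

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

data = torch.tensor([0.5, -1.0, 2.0])
observed = torch.tensor([True, True, False])  # the last entry is missing

def model():
    with pyro.plate("data", 3):
        # observed entries contribute to the likelihood;
        # the missing entry is handled through an auxiliary latent site
        pyro.sample("x", dist.Normal(0., 1.), obs=data, obs_mask=observed)

tr = poutine.trace(model).get_trace()
print(list(tr.nodes.keys()))
# expect something like [..., 'data', 'x_observed', 'x_unobserved', 'x', ...]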
because the .to_event(1) is doing something (in combination with the poutine.mask), but not sure what
Looking at your code,
mask_t = torch.tensor([True,True])
logits = torch.tensor([3.,4.])
with pyro.poutine.mask(mask=mask_t):
    pyro.sample("c", dist.Categorical(logits=logits).to_event(1))
will raise an error, see this line. If it works for your code, then please raise a separate issue with small reproducible code.
@fehiepsi Noted, looking into making a reproducible example, brb
@fehiepsi Never mind, it did not raise an error because I had pyro.enable_validation(False).
It is still weird that it gives good results with pyro.enable_validation(False) and using
with pyro.poutine.mask(mask=mask_t):
    pyro.sample("c", dist.Categorical(logits=logits).to_event(1))
However, when using
pyro.sample("c", dist.Categorical(logits=logits).to_event(1))
or
with pyro.poutine.mask(mask=mask_t):
    pyro.sample("c", dist.Categorical(logits=logits))
the results are random.
I will try to code everything again with pyro.enable_validation(True).
I think you can't use to_event here:
import torch
import pyro
import pyro.distributions as dist
pyro.enable_validation(False)
logits = torch.tensor([3.,4.])
dist.Categorical(logits=logits).to_event(1)
would raise an error. Maybe you can check the shapes of your logits again to get a better understanding of why there is no error in your system. I don't think that it's due to enable_validation: if your distribution has no batch dimensions, .to_event(1) will raise an error whether or not we enable validation.
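For completeness, a quick check (a sketch; the exact error message may differ across torch/Pyro versions) that the failure does not depend on validation:

import torch
import pyro
import pyro.distributions as dist

pyro.enable_validation(False)  # validation off; the error is raised anyway
logits = torch.tensor([3., 4.])
try:
    dist.Categorical(logits=logits).to_event(1)  # batch_shape is (), nothing to reinterpret
except ValueError as e:
    print("ValueError:", e)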
@fehiepsi Oh, ok, interesting; well, my logits simply have shape [N, num_classes], where N is the number of data points and num_classes is the number of classes. Let me think about it more and then come back.
@fehiepsi Should I open another issue with these examples? They fail when enable_validation is True, not otherwise.
import torch
from torch import tensor
from pyro import sample,plate
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import SVI,Trace_ELBO
from pyro.optim import ClippedAdam
import pyro
def model1(x, obs_mask, x_class, class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    :return:
    """
    z = sample("z", dist.Normal(torch.zeros((2,5)), torch.ones((2,5))).to_event(2))
    logits = torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
                           [[1,2,7],[0,2,1],[2,7,8]]])
    aa = sample("x", dist.Categorical(logits=logits), obs=x)
    with pyro.poutine.mask(mask=class_mask):
        c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).to_event(1), obs=x_class)
    return z, c, aa

def model2(x, obs_mask, x_class, class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    :return:
    """
    z = sample("z", dist.Normal(torch.zeros((2,5)), torch.ones((2,5))).to_event(2))
    logits = torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
                           [[1,2,7],[0,2,1],[2,7,8]]])
    aa = sample("x", dist.Categorical(logits=logits).mask(obs_mask).to_event(1), obs=x)
    c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).to_event(1), obs=x_class)
    return z, c

def model3(x, obs_mask, x_class, class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    :return:
    """
    z = sample("z", dist.Normal(torch.zeros((2,5)), torch.ones((2,5))).to_event(2))
    logits = torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
                           [[1,2,7],[0,2,1],[2,7,8]]])
    aa = sample("x", dist.Categorical(logits=logits).to_event(1), obs=x)
    c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).to_event(1), obs=x_class)
    return z, c, aa

def model4(x, obs_mask, x_class, class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    :return:
    """
    z = sample("z", dist.Normal(torch.zeros((2,5)), torch.ones((2,5))).to_event(2))
    logits = torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
                           [[1,2,7],[0,2,1],[2,7,8]]])
    aa = sample("x", dist.Categorical(logits=logits), obs=x, obs_mask=obs_mask)  # partial observations is what I am looking for here
    c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).mask(class_mask), obs=x_class)  # in the fully supervised approach no mask here, but in the semi-supervised I would need to fully mask some observations
    return z, c, aa

def model5(x, obs_mask, x_class, class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    :return:
    """
    with pyro.plate("plate_batch", dim=-1):
        z = sample("z", dist.Normal(torch.zeros((2,5)), torch.ones((2,5))).to_event(1))
        logits = torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
                               [[1,2,7],[0,2,1],[2,7,8]]])
        aa = sample("x", dist.Categorical(logits=logits), obs=x, obs_mask=obs_mask)  # partial observations is what I am looking for here
        c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).mask(class_mask), obs=x_class)
    return z, c, aa

def guide(x, obs_mask, x_class, class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    """
    z = sample("z", dist.Normal(torch.zeros((2,5)), torch.ones((2,5))).to_event(2))
    return z

if __name__ == "__main__":
    pyro.enable_validation(False)
    x = tensor([[0,2,1],
                [0,1,1]])
    obs_mask = tensor([[1,0,0],[1,1,0]], dtype=bool)  # Partial observations
    x_class = tensor([0,1])
    class_mask = tensor([True,False], dtype=bool)  # keep/skip some observations
    models_dict = {"model1": model1,
                   "model2": model2,
                   "model3": model3,
                   "model4": model4,
                   "model5": model5,
                   }
    for model in models_dict.keys():
        print("Using {}".format(model))
        guide_tr = poutine.trace(guide).get_trace(x, obs_mask, x_class, class_mask)
        model_tr = poutine.trace(poutine.replay(models_dict[model], trace=guide_tr)).get_trace(x, obs_mask, x_class, class_mask)
        monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()
        print("MC ELBO estimate: {}".format(monte_carlo_elbo))
        try:
            pyro.clear_param_store()
            svi = SVI(models_dict[model], guide, loss=Trace_ELBO(), optim=ClippedAdam(dict()))
            svi.step(x, obs_mask, x_class, class_mask)
            print("Test passed")
        except:
            print("Test failed")
By the way, I think I want something like model4 or model5. I welcome suggestions on what to do about them, because I am not sure how to handle partial observations of "x" in the model; do I have to do something in the guide? What about semi-supervised approaches for the "c" variable?
Your last example is different from the previous one. Now your logits have shape (N, num_classes), so your categorical distribution will have batch shape (N,) -> to_event will work.
Should I open another issue
I'm not sure if there is an issue here. to_event should work with 2-dimensional logits. Note that with 2D logits, the code
with mask(mask):
    sample(..., Categorical(logits).to_event(1), obs=...)
will give you a distribution/log_prob with batch_shape = mask.shape and event_shape (N,). Hope that this clarifies the semantics of mask. If your data has shape (N,), the log likelihood of the above code will be the same as
(dist.Categorical(logits).log_prob(data).sum(-1) * mask).sum()
which is the same as
dist.Categorical(logits).log_prob(data).sum(-1) * mask.sum()
In other words, you are scaling the log likelihood by a factor of mask.sum(). There is no "masking" applied here. It makes sense that you can get good results by scaling the likelihood.
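A small sketch checking this (the logits of shape (2, 2), i.e. N = 2, are hypothetical):

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

logits = torch.tensor([[3., 4.], [1., 2.]])  # shape (N, num_classes) with N = 2
data = torch.tensor([0, 1])
mask = torch.tensor([True, True])

def model():
    with poutine.mask(mask=mask):
        pyro.sample("c", dist.Categorical(logits=logits).to_event(1), obs=data)

lp_site = poutine.trace(model).get_trace().log_prob_sum()
lp_scaled = dist.Categorical(logits=logits).log_prob(data).sum(-1) * mask.sum()
print(lp_site, lp_scaled)  # the two agree: the mask rescales the likelihood rather than masking entries out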
@fehiepsi Oh, ok, I see where the misunderstanding with .to_event() started; my first example was not a good one, sorry.
I need a fresh mind to reflect on the last part, because that would mean that I accidentally scaled up the likelihood and therefore made the training more efficient? That is so interesting.
Then I want to do the same with the variable "x" hahaha (but keeping the partial observations).
Hi!
As discussed here https://forum.pyro.ai/t/more-doubts-on-masking-runnable-example/5044/6 and here https://forum.pyro.ai/t/vae-classification/5017/10, it might not be very clear when and how to use the different masking options, especially regarding the differences in masking usage in the model vs. the guide, or masking with enumeration.
Thanks! :)