I have a Bayesian network model as a pomegranate object and only need to update its parameters, so it seems I should be able to use model.fit. As a sanity check, I pass 1000 samples generated from the model into the fit routine (I realize this should not change the parameters significantly, since the samples come from the model itself). The generated samples match the expected output, based on the node and edge structure I define.
samples = model.sample(1000)
model.fit(samples)
This fails with:
RuntimeError: scatter(): Expected self.dtype to be equal to src.dtype
(in conditional_categorical.py, line 168, in summarize: self._xw_sum[j].view(-1).scatter_add_(0, X_, sample_weight[:,j])). When I check the dtypes in conditional_categorical.py I see:
self.dtype
Out[1]: torch.float32
X.dtype
Out[2]: torch.int32
I then tried:
but this fails with:
ValueError: Parameter X dtype must be one of (torch.int32, torch.int64)
My conda env has pomegranate 1.0.3 (also torchegranate 0.5.0). I am not sure whether this is more of a PyTorch issue; I have seen the PyTorch change "add dtype checks for scatter/gather family of functions": https://github.com/pytorch/pytorch/pull/38646
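To separate the PyTorch behaviour from pomegranate, the same dtype check can be reproduced with plain tensors. This is only a minimal sketch: the names acc, fixed, idx and weights64 are hypothetical stand-ins, not pomegranate internals, and it assumes the mismatch is between a float32 accumulator and float64 weights, which has not been confirmed for the model above.

```python
import torch

# Hypothetical stand-ins, not pomegranate internals: int64 indices and
# float64 weights that get accumulated into a float32 buffer.
idx = torch.tensor([0, 1, 1, 3], dtype=torch.int64)
weights64 = torch.ones(4, dtype=torch.float64)

# Mismatched dtypes: recent PyTorch releases are expected to reject this call.
acc = torch.zeros(4, dtype=torch.float32)
try:
    acc.scatter_add_(0, idx, weights64)
    mismatch_error = None
except RuntimeError as exc:
    mismatch_error = str(exc)
print("mismatch error:", mismatch_error)

# Casting the source to the accumulator's dtype lets the call go through.
fixed = torch.zeros(4, dtype=torch.float32)
fixed.scatter_add_(0, idx, weights64.to(fixed.dtype))
print(fixed)
```

If the same thing is happening inside the model, casting the float64 probabilities to float32 before building the distributions may be worth trying, though that is a guess rather than a confirmed fix.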
I am able to get the following simple test case to work. The only difference from the example above is that the failing model appears to have some torch.float64 values in the transition probabilities (they come from a more complex computation).