    model = GeneralMixtureModel((Poisson(), Normal())).fit(data)

and the following RuntimeError was raised:
      File "/home/anaconda3/envs/scanpy/lib/python3.10/site-packages/pomegranate/gmm.py", line 245, in fit
        logp = self.summarize(X, sample_weight=sample_weight,
      File "/home/anaconda3/envs/scanpy/lib/python3.10/site-packages/pomegranate/gmm.py", line 320, in summarize
        d.summarize(X, y[:, i:i+1] * sample_weight)
      File "/home/anaconda3/envs/scanpy/lib/python3.10/site-packages/pomegranate/distributions/normal.py", line 264, in summarize
        self._xxw_sum += torch.matmul((X * sample_weight).T, X)
    RuntimeError: expected scalar type Double but found Float
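It seems that the two tensors need a dtype conversion before torch.matmul multiplies them: one operand is float64 (Double) and the other is float32 (Float), and torch.matmul does not promote mixed dtypes on its own.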
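A minimal sketch of the kind of conversion meant above, written in plain PyTorch rather than pomegranate's actual code. `weighted_xxw_sum` is a hypothetical helper that mirrors the failing line from normal.py, and the tensors are made-up values chosen only to reproduce the dtype mismatch; which operand pomegranate should cast is an open question.

```python
import torch


def weighted_xxw_sum(X: torch.Tensor, sample_weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: compute (X * w).T @ X, as in the traceback,
    after casting both operands to one common dtype."""
    # promote_types returns the wider of the two dtypes (float64 for float32 vs float64).
    dtype = torch.promote_types(X.dtype, sample_weight.dtype)
    X = X.to(dtype)
    sample_weight = sample_weight.to(dtype)
    return torch.matmul((X * sample_weight).T, X)


# Made-up inputs: X is float32, the weights are float64.
X = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float32)
w = torch.tensor([[0.5], [1.0], [1.0]], dtype=torch.float64)

# Without a cast, X * w is promoted to float64 while X stays float32,
# so torch.matmul rejects the mixed-dtype operands.
try:
    torch.matmul((X * w).T, X)
except RuntimeError as err:
    print("mixed dtypes:", err)

result = weighted_xxw_sum(X, w)
print(result.dtype, result)
```

Promoting to the wider dtype avoids silently dropping precision; casting everything to float32 instead would match PyTorch's default dtype.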
To Reproduce
If needed, I can extract the data from the other scripts I used, but I am not sure the data is required to reproduce this bug.
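Since the real data was not shared, a small check like the one below could help pin down where the Float/Double mix comes from. The array is a hypothetical stand-in for `data`; it only shows that NumPy float arrays default to float64, that `torch.as_tensor` keeps that dtype, and that PyTorch's own default is float32. Whether casting `data` to float32 before `fit` avoids the error is an untested assumption.

```python
import numpy as np
import torch

# Hypothetical stand-in for the real `data`, which was not shared.
data = np.array([[0.0], [3.0], [7.0]])

print("numpy dtype:", data.dtype)                   # float64 by default
print("as tensor:", torch.as_tensor(data).dtype)    # dtype is kept: torch.float64
print("torch default:", torch.get_default_dtype())  # torch.float32 unless changed

# Untested workaround idea: give every component float32 input up front.
data32 = data.astype(np.float32)
print("after cast:", data32.dtype)
```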
Environment: pomegranate v1.0.0 on Python 3.10.