jmschrei / pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.
http://pomegranate.readthedocs.org/en/latest/
MIT License

[BUG] RuntimeError: expected scalar type Double but found Float #1045

Closed: wook2014 closed this issue 1 year ago

wook2014 commented 1 year ago

Describe the bug: I am using pomegranate v1.0.0, and the following call

model = GeneralMixtureModel((Poisson(), Normal())).fit(data)

raises a runtime error:

  File "/home/anaconda3/envs/scanpy/lib/python3.10/site-packages/pomegranate/gmm.py", line 245, in fit
    logp = self.summarize(X, sample_weight=sample_weight,
  File "/home/anaconda3/envs/scanpy/lib/python3.10/site-packages/pomegranate/gmm.py", line 320, in summarize
    d.summarize(X, y[:, i:i+1] * sample_weight)
  File "/home/anaconda3/envs/scanpy/lib/python3.10/site-packages/pomegranate/distributions/normal.py", line 264, in summarize
    self._xxw_sum += torch.matmul((X * sample_weight).T, X)
RuntimeError: expected scalar type Double but found Float

It seems that a dtype conversion is needed before torch.matmul operates on the two tensors.

To Reproduce: If necessary, I will try to extract the data from the other scripts I have used. I am not sure the data is needed to reproduce this bug.
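The mismatch the traceback points at can be reproduced in isolation. A minimal sketch, assuming PyTorch is installed and assuming that one operand of the failing matmul ends up float64 while the other is float32 (the thread does not say which internal tensor carries which dtype; X and sample_weight below only mirror the names in the traceback):

```python
import torch

# Hypothetical stand-ins: names mirror the traceback line, dtypes are assumed.
X = torch.ones(4, 2, dtype=torch.float32)              # data tensor as float32
sample_weight = torch.ones(4, 1, dtype=torch.float64)  # weights as float64

# Elementwise ops follow PyTorch's type promotion, so the product is float64.
Xw = X * sample_weight
print(Xw.dtype)  # torch.float64

# torch.matmul does not promote: a float64 operand times a float32 one raises.
try:
    torch.matmul(Xw.T, X)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", err)  # exact wording varies across PyTorch versions

# Casting to a single dtype first makes the product well defined.
xxw_sum = torch.matmul(Xw.T.to(X.dtype), X)
print(xxw_sum.dtype)  # torch.float32
```

The point of the sketch is that elementwise multiplication upcasts silently while matrix multiplication requires matching dtypes, which is why the error only surfaces at the matmul call.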

jmschrei commented 1 year ago

What is the dtype of data? It should be float32.

wook2014 commented 1 year ago


I have checked: it is a float64 numpy.ndarray.

jmschrei commented 1 year ago

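Cast it to float32: data = data.type(torch.float32) if data is a torch tensor, or data = data.astype(numpy.float32) if it is a NumPy array (NumPy arrays have no .type() method).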
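The suggested data.type(torch.float32) works on a torch tensor, but the data reported in this thread is a float64 numpy.ndarray, which has no .type() method. A minimal sketch of the cast for either kind of input, assuming NumPy and PyTorch are installed; as_float32 is a hypothetical helper, not part of pomegranate:

```python
import numpy as np
import torch

def as_float32(data):
    """Hypothetical helper: return `data` cast to float32.

    NumPy arrays are cast with .astype(); torch tensors use .type(torch.float32).
    """
    if isinstance(data, np.ndarray):
        return data.astype(np.float32)
    if isinstance(data, torch.Tensor):
        return data.type(torch.float32)
    raise TypeError(f"unsupported input type: {type(data).__name__}")

data = np.random.default_rng(0).normal(size=(100, 3))  # float64 by default
print(data.dtype)  # float64
data32 = as_float32(data)
print(data32.dtype)  # float32

# Alternatively, convert straight to a float32 tensor in one step.
X = torch.tensor(data, dtype=torch.float32)
print(X.dtype)  # torch.float32
```

The float32 result would then be passed to fit in place of the original float64 array.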