yoyololicon / pytorch-NMF

A pytorch package for non-negative matrix factorization.
https://pytorch-nmf.readthedocs.io/
MIT License

Solve issue where beta=2.0 didn't work due to floating type mismatch #4

Closed · akashpalrecha closed this 3 years ago

akashpalrecha commented 3 years ago

Reproduce the error using:

from torchnmf.nmf import NMF  # import assumed for a self-contained repro; cube and H are user-provided float64 tensors

net = NMF(cube.shape, rank=7).cuda()
_, V = net.fit_transform(cube, H=H, update_H=False, beta=2.0, verbose=1, tol=1e-12, max_iter=2000, alpha=0.0, l1_ratio=0.0)

where cube and H are both torch.float64 tensors. Whenever beta is 2.0, a floating-point dtype mismatch occurs at the loss.backward() step of the optimization.
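For illustration, here is a minimal standalone sketch (not the library's code) of this kind of failure, using plain tensors in place of the NMF module's parameters: float32 parameters are combined with a float64 target, and the dtype mismatch surfaces once backward() runs.

import torch
import torch.nn.functional as F

V = torch.rand(8, 8, dtype=torch.float64)   # stands in for the float64 input (like cube)
W = torch.rand(8, 4, requires_grad=True)    # float32, PyTorch's default dtype
H = torch.rand(4, 8, requires_grad=True)

loss = F.mse_loss(W @ H, V)  # beta=2.0 corresponds to a squared-error (Euclidean) loss
loss.backward()              # raises a dtype mismatch RuntimeError (or already in the forward call, depending on the PyTorch version)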

This PR solves that issue.

yoyololicon commented 3 years ago

@akashpalrecha Thanks for the report. If the parameter types mismatch, I would recommend calling double() at the module level rather than inside the __init__ call, since float32 is the default dtype in PyTorch.

net = NMF(cube.shape, rank=7).double().cuda()
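As a quick sanity check (a sketch, continuing the repro above and assuming cube and H are float64 CUDA tensors), the module-level double() call converts all registered parameters to float64, so the fit should run without the dtype mismatch:

net = NMF(cube.shape, rank=7).double().cuda()
assert all(p.dtype == torch.float64 for p in net.parameters())  # parameters now match cube's dtype
_, V = net.fit_transform(cube, H=H, update_H=False, beta=2.0, verbose=1, tol=1e-12, max_iter=2000, alpha=0.0, l1_ratio=0.0)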