yoyololicon / pytorch-NMF

A pytorch package for non-negative matrix factorization.
https://pytorch-nmf.readthedocs.io/
MIT License

Provide support for arbitrary devices and dtypes #24

Closed: mikesha2 closed this 2 years ago

mikesha2 commented 2 years ago

Simple changes to BaseComponent to support arbitrary PyTorch devices, including CUDA and, in PyTorch >= 1.12, Apple Silicon Metal Performance Shaders (MPS).

On Apple Silicon, MPS only supports operations on torch.float32.

Tested on a 2021 M1 Max MacBook Pro.

Example:

import torch
from torchnmf.nmf import NMF
model = NMF((100, 100), rank=10, device='mps', dtype=torch.float32)
model.fit(torch.randn(100, 100, dtype=torch.float32).to('mps').abs())
print(model.W, model.H)

yoyololicon commented 2 years ago

@mikesha2 Thanks for the contribution. Have you tried NMF((100, 100), rank=10).to('mps').float() before? Did it work the same?

mikesha2 commented 2 years ago

Oh I see, it's a different style than I'm used to!
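
For readers comparing the two styles, here is a minimal sketch of the module-style idiom suggested above. It assumes the model behaves like a regular torch.nn.Module, so that `.to(device)` and `.float()` move and cast its parameters after construction. `TinyFactors` and `pick_device` are hypothetical names made up for this illustration and are not part of torchnmf; the device fallback is added only so the snippet also runs on machines without MPS.

```python
import torch
from torch import nn


def pick_device() -> torch.device:
    """Hypothetical helper: prefer MPS, then CUDA, then CPU."""
    mps = getattr(torch.backends, "mps", None)  # absent before PyTorch 1.12
    if mps is not None and mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


class TinyFactors(nn.Module):
    """Hypothetical stand-in for a factorization model; not torchnmf code."""

    def __init__(self, shape, rank):
        super().__init__()
        rows, cols = shape
        # Parameters start on the CPU with the default dtype (float32).
        self.W = nn.Parameter(torch.rand(rows, rank))
        self.H = nn.Parameter(torch.rand(cols, rank))


device = pick_device()

# Module style: build first, then move and cast, as in the suggestion above.
model = TinyFactors((100, 100), rank=10).to(device).float()

# Non-negative input created directly on the target device.
V = torch.rand(100, 100, dtype=torch.float32, device=device)

print(model.W.device, model.W.dtype, V.device)
```

With this style the constructor needs no device or dtype arguments, since placement is handled by the standard nn.Module methods. Note that the order matters if the parameters start as float64: casting with `.float()` before `.to('mps')` would avoid moving an unsupported dtype to MPS.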