dddlli closed this issue 1 month ago.
Hello there, thank you for opening an issue! 🙏🏻 The team was notified and will get back to you as soon as possible.
I'm hitting the same bug:
File "/home/miniconda3/lib/python3.11/site-packages/zeta/nn/modules/simple_mamba.py", line 205, in selective_scan
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Adding .to(x.device) works around it, but I think that is inefficient.
Replace line 202 in zeta/nn/modules/simple_mamba.py with the following:
x = torch.zeros((b, d_in, n), device=next(self.parameters()).device)
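For context, here is a minimal, self-contained sketch of a sequential selective scan in the style of the Mamba reference code. It is not zeta's actual implementation: the function name `selective_scan_sketch`, the tensor layout (`u` and `delta` as `(b, l, d_in)`, `A` as `(d_in, n)`, `B` and `C` as `(b, l, n)`, `D` as `(d_in,)`) and the use of `torch.einsum` are assumptions. The relevant line is the state allocation: passing `device=` (and `dtype=`) to `torch.zeros` creates the state directly where the other tensors live, whereas `torch.zeros(...).to(device)` first allocates on the CPU and then copies. Taking the device from `deltaA` rather than `next(self.parameters())` also works when the scan is a plain function with no `self`.

```python
import torch


def selective_scan_sketch(u, delta, A, B, C, D):
    """Sequential selective scan (hypothetical helper, modeled on the Mamba reference).

    u, delta: (b, l, d_in); A: (d_in, n); B, C: (b, l, n); D: (d_in,)
    Returns y: (b, l, d_in).
    """
    b, l, d_in = u.shape
    n = A.shape[1]

    # Discretize: deltaA = exp(delta * A), deltaB_u = delta * B * u
    deltaA = torch.exp(torch.einsum("bld,dn->bldn", delta, A))
    deltaB_u = torch.einsum("bld,bln,bld->bldn", delta, B, u)

    # Allocate the hidden state on the same device/dtype as the inputs,
    # so no CPU tensor ever meets a CUDA tensor inside the loop.
    x = torch.zeros((b, d_in, n), device=deltaA.device, dtype=deltaA.dtype)

    ys = []
    for i in range(l):
        x = deltaA[:, i] * x + deltaB_u[:, i]       # (b, d_in, n)
        y = torch.einsum("bdn,bn->bd", x, C[:, i])  # (b, d_in)
        ys.append(y)
    y = torch.stack(ys, dim=1)                      # (b, l, d_in)
    return y + u * D                                # skip connection


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    b, l, d_in, n = 2, 8, 16, 4
    u = torch.randn(b, l, d_in, device=device)
    delta = torch.rand(b, l, d_in, device=device)
    A = -torch.rand(d_in, n, device=device)  # A <= 0 keeps exp(delta * A) <= 1
    B = torch.randn(b, l, n, device=device)
    C = torch.randn(b, l, n, device=device)
    D = torch.randn(d_in, device=device)
    y = selective_scan_sketch(u, delta, A, B, C, D)
    print(y.shape, y.device)
```

Because every tensor in the loop is derived from the inputs, the same function runs unchanged on CPU or GPU as long as the caller keeps the inputs on one device.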
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)
import torch
from zeta.nn import MambaBlock
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
block = MambaBlock(dim=64, depth=1)
x = torch.randn(1, 10, 64).to(device)
y = block(x).to(device)
print(y.shape)
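In the snippet above, only `x` is moved to `device`; `block` itself is never moved, so on a CUDA machine its weights stay on the CPU and the first matrix multiply inside the block likely sees one `cuda:0` operand and one `cpu` operand, which would match the `mat2 ... wrapper_CUDA_mm` error. Calling `.to(device)` on `y` afterwards cannot help, because the error is raised inside `block(x)`. The sketch below shows the general pattern with a small stand-in module (`TinyScanBlock` is hypothetical, not a zeta class): move the module with `module.to(device)`, create the input on the same device, and allocate any per-call state from the input's device. With the real library this would correspond to `MambaBlock(dim=64, depth=1).to(device)`, although the `torch.zeros` fix inside `simple_mamba.py` may still be needed.

```python
import torch
from torch import nn


class TinyScanBlock(nn.Module):
    """Hypothetical stand-in for a Mamba-style block (not zeta's MambaBlock)."""

    def __init__(self, dim: int):
        super().__init__()
        self.in_proj = nn.Linear(dim, dim, bias=False)
        self.decay = nn.Parameter(torch.full((dim,), 0.9))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, l, d = x.shape
        u = self.in_proj(x)  # raises a device-mismatch error if weights and x differ
        # Per-call state created directly on x's device; no .to() copy needed.
        h = torch.zeros((b, d), device=x.device, dtype=x.dtype)
        ys = []
        for i in range(l):
            h = self.decay * h + u[:, i]  # simple linear recurrence over time
            ys.append(h)
        return torch.stack(ys, dim=1)     # (b, l, d)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
block = TinyScanBlock(dim=64).to(device)  # move the parameters, not just the input
x = torch.randn(1, 10, 64, device=device)
y = block(x)                              # already on `device`; no extra .to() needed
print(y.shape)  # torch.Size([1, 10, 64])
```

`nn.Module.to` moves parameters and buffers in place and returns the module, so chaining it onto the constructor is the usual idiom.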