kyegomez / zeta

Build high-performance AI models with modular building blocks
https://zeta.apac.ai
Apache License 2.0

an issue for the "MambaBlock" in "Zeta.nn" #236

Open chen-gui opened 1 week ago

chen-gui commented 1 week ago

When I use the "MambaBlock" from "Zeta.nn" on the GPU, I get the following error: "RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!" I am sure that the error comes from the "MambaBlock". My guess is that a tensor operation in the "ssm" method is not being placed on CUDA, even though I have already moved the entire network to CUDA.
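One way to test this guess is to list every tensor the module holds that is not on the expected device. The sketch below is only an illustration: `tensors_off_device` and the `Toy` module are hypothetical names, not part of zeta. It relies on the fact that `nn.Module.to(device)` moves parameters and registered buffers but leaves plain tensor attributes where they are. It cannot detect tensors that are created inside `forward` without a `device=` argument, which is another common cause of this error.

```python
from typing import List

import torch
import torch.nn as nn


def tensors_off_device(module: nn.Module, expected: torch.device) -> List[str]:
    """Return names of tensors held by `module` whose device type differs from `expected`.

    Hypothetical helper for debugging; not part of zeta or PyTorch.
    """
    off = []
    # Parameters and registered buffers are what .to(device) actually moves.
    for name, t in list(module.named_parameters()) + list(module.named_buffers()):
        if t.device.type != expected.type:
            off.append(name)
    # Plain tensor attributes live in each submodule's __dict__ and are NOT moved.
    for mod_name, sub in module.named_modules():
        for attr, value in vars(sub).items():
            if isinstance(value, torch.Tensor) and value.device.type != expected.type:
                off.append(f"{mod_name}.{attr}".lstrip("."))
    return off


class Toy(nn.Module):
    """Hypothetical module with one tensor of each kind."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)                 # parameters
        self.register_buffer("scale", torch.ones(2))  # registered buffer
        self.offset = torch.zeros(2)                  # plain attribute, ignored by .to()
```

With the code from this report, calling `tensors_off_device(model, device)` after `model = MambaBlock(...).to(device)` would show whether any stored tensor stayed on the CPU. An empty list would point instead to a tensor allocated during the forward pass.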

The complete test code is as follows:

    import torch.nn as nn
    import torch
    from zeta.nn import MambaBlock

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.Tensor(1, 2048, 1).to(device)
    print("x size: {}".format(x.size()))

    model = MambaBlock(dim=1, expand=4, depth=1, d_state=8, d_conv=3, bias=True).to(device)
    print(model)

    out = model(x)
    print("out size: {}".format(out.size()))

The corresponding complete error is as follows:

    Traceback (most recent call last):
      File "E:/DAS_denosing/2024_03_06_SEGXiAn_DAS/unsupervised/Mamba.py", line 141, in <module>
        out = model(x)
      File "C:\Anaconda3\envs\tensorflow2.3\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "C:\Anaconda3\envs\tensorflow2.3\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
        return forward_call(*args, **kwargs)
      File "C:\Anaconda3\envs\tensorflow2.3\lib\site-packages\zeta\nn\modules\simple_mamba.py", line 118, in forward
        y = self.ssm(x)
      File "C:\Anaconda3\envs\tensorflow2.3\lib\site-packages\zeta\nn\modules\simple_mamba.py", line 158, in ssm
        y = self.selective_scan(
      File "C:\Anaconda3\envs\tensorflow2.3\lib\site-packages\zeta\nn\modules\simple_mamba.py", line 205, in selective_scan
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
    RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
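The last frame of the traceback is a sequential recurrence over the sequence dimension. A likely cause, although this is an assumption because the zeta source is not quoted here, is that the running state `x` is allocated with something like `torch.zeros(...)` without a `device=` argument, so it lands on the CPU while `deltaA` and `deltaB_u` are on `cuda:0`. The sketch below is not zeta's implementation. The function name `scan_sketch` and the tensor layout `(batch, d_in, seq_len, n)` are hypothetical, and the code only illustrates allocating the state on the same device and dtype as the inputs.

```python
import torch


def scan_sketch(deltaA: torch.Tensor, deltaB_u: torch.Tensor) -> torch.Tensor:
    """Run state = deltaA_i * state + deltaB_u_i along dim 2 and return all states.

    Hypothetical illustration; assumes both inputs have shape (batch, d_in, seq_len, n).
    """
    b, d_in, seq_len, n = deltaA.shape
    # Allocate the running state on the inputs' device and dtype,
    # instead of relying on the CPU default of torch.zeros.
    state = torch.zeros(b, d_in, n, device=deltaA.device, dtype=deltaA.dtype)
    states = []
    for i in range(seq_len):
        state = deltaA[:, :, i] * state + deltaB_u[:, :, i]
        states.append(state)
    # Stack along the sequence dimension: (batch, d_in, seq_len, n).
    return torch.stack(states, dim=2)
```

Because the state takes its device from `deltaA`, the same function runs unchanged on the CPU or on CUDA. If the state in "selective_scan" is created without a device, passing `device=` in the same way there would be one candidate fix.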

github-actions[bot] commented 1 week ago

Hello there, thank you for opening an Issue ! 🙏🏻 The team was notified and they will get back to you asap.