kyegomez / zeta

Build high-performance AI models with modular building blocks
https://zeta.apac.ai
Apache License 2.0
320 stars 28 forks

[BUG] Why is the backpropagation calculation so slow when I use the mamba network? #220

Open 1325116124 opened 1 month ago

1325116124 commented 1 month ago

When using the Mamba network, I defined a loss to test backpropagation and found the computation very slow: with the sequence length set to 1024, the backward pass takes a long time. Code shown below:

```python
import torch
import torch.nn as nn
from zeta.nn import MambaBlock

block = MambaBlock(dim=512, depth=1)
x = torch.randn(1, 1024, 512)
target = torch.randn(1, 1024, 512)
loss_fn = nn.MSELoss()

y = block(x)
loss = loss_fn(y, target)
loss.backward()
print("Output shape:", y.shape)
print("Loss value:", loss.item())
```
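To narrow down where the time goes, it can help to time the forward and backward passes separately. A minimal sketch of that pattern, using a plain `nn.Sequential` stack as a stand-in model (an assumption for illustration; swap in `MambaBlock(dim=512, depth=1)` to profile the real block):

```python
import time

import torch
import torch.nn as nn

# Stand-in model; replace with MambaBlock to measure the actual issue.
model = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 512))
x = torch.randn(1, 1024, 512)
target = torch.randn(1, 1024, 512)
loss_fn = nn.MSELoss()

t0 = time.perf_counter()
y = model(x)               # forward pass
t1 = time.perf_counter()
loss = loss_fn(y, target)
loss.backward()            # backward pass
t2 = time.perf_counter()

print(f"forward:  {t1 - t0:.4f}s")
print(f"backward: {t2 - t1:.4f}s")
```

If the backward time dominates by a large factor, that points at the scan's autograd graph rather than the forward computation itself.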


github-actions[bot] commented 1 month ago

Hello there, thank you for opening an Issue ! 🙏🏻 The team was notified and they will get back to you asap.

kyegomez commented 3 weeks ago

@1325116124 it's using the Mamba scan (the SSM recurrence); it should be updated soon!
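For context on why a naive SSM scan is costly to differentiate: the state recurrence h_t = a * h_{t-1} + b * x_t must be unrolled step by step, so both the forward pass and autograd's backward pass walk O(L) sequential operations in the sequence length (here 1024). A toy scalar sketch of that recurrence, purely illustrative and not zeta's implementation:

```python
def sequential_scan(a, b, xs):
    """Unrolled linear recurrence h_t = a * h_{t-1} + b * x_t.

    Each state depends on the previous one, so the loop cannot be
    parallelized naively; an autograd engine likewise has to traverse
    every step in reverse, which makes long sequences slow to backprop.
    """
    h = 0.0
    states = []
    for x in xs:
        h = a * h + b * x
        states.append(h)
    return states

states = sequential_scan(a=0.5, b=1.0, xs=[1.0, 2.0, 3.0])
print(states)  # -> [1.0, 2.5, 4.25]
```

Optimized Mamba implementations avoid this by fusing the scan into a single kernel (or using a parallel/associative scan), which is presumably what the planned update refers to.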