When I used the Mamba network, I defined a loss to test backpropagation and found that the computation was very slow: with a sequence length of 1024, a single backward pass takes a long time. The code is shown below:
```python
import torch
import torch.nn as nn
from zeta.nn import MambaBlock

block = MambaBlock(dim=512, depth=1)
x = torch.randn(1, 1024, 512)
target = torch.randn(1, 1024, 512)
loss_fn = nn.MSELoss()

y = block(x)
loss = loss_fn(y, target)
loss.backward()
print("Output shape:", y.shape)
print("Loss value:", loss.item())
```
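To see whether the time is spent in the forward pass or in `loss.backward()`, a minimal timing sketch like the following can separate the two. Here a plain `nn.Linear` stands in for `MambaBlock` so the snippet runs without zeta installed; swap in `MambaBlock(dim=512, depth=1)` to profile the real block (on GPU, also call `torch.cuda.synchronize()` before each timestamp so the measurements are accurate).

```python
import time

import torch
import torch.nn as nn


def time_fwd_bwd(block: nn.Module, x: torch.Tensor, target: torch.Tensor):
    """Return (forward_seconds, backward_seconds) for one training step."""
    loss_fn = nn.MSELoss()

    t0 = time.perf_counter()
    y = block(x)                 # forward pass
    t1 = time.perf_counter()

    loss = loss_fn(y, target)
    t2 = time.perf_counter()
    loss.backward()              # backward pass
    t3 = time.perf_counter()

    return t1 - t0, t3 - t2


# Stand-in module; replace with MambaBlock(dim=512, depth=1) from zeta.nn
# to reproduce the slowdown reported above.
block = nn.Linear(512, 512)
x = torch.randn(1, 1024, 512)
target = torch.randn(1, 1024, 512)

fwd_s, bwd_s = time_fwd_bwd(block, x, target)
print(f"forward: {fwd_s:.4f}s  backward: {bwd_s:.4f}s")
```

Comparing the two numbers shows whether backpropagation itself is the bottleneck or whether the forward recurrence already dominates at this sequence length.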