kyegomez / Jamba

PyTorch Implementation of Jamba: "Jamba: A Hybrid Transformer-Mamba Language Model"
https://discord.gg/7VckQVxvKk
MIT License

[BUG] The example on the README is not working #2

Closed bruAristimunha closed 4 months ago

bruAristimunha commented 7 months ago

Describe the bug

I tried to run the example code out of curiosity in a clean environment and could not reproduce the expected output.

To Reproduce

The same failure reproduces in this Colab notebook: https://colab.research.google.com/drive/1OnyI7WfXUkqXqscz8QiFUErDHA2kEfm5?usp=sharing

Steps to reproduce the behavior:

  1. Create a fresh conda, pipenv, or colab env.
  2. Go to the README file.
  3. Install jamba
  4. Run the example:
    
```python
# Import the torch library, which provides tools for machine learning
import torch

# Import the Jamba model from the jamba.model module
from jamba.model import Jamba

# Create a tensor of random integers between 0 and 100, with shape (1, 100)
# This simulates a batch of tokens that we will pass through the model
x = torch.randint(0, 100, (1, 100))

# Initialize the Jamba model with the specified parameters
# dim: dimensionality of the input data
# depth: number of layers in the model
# num_tokens: number of unique tokens in the input data
# d_state: dimensionality of the hidden state in the model
# d_conv: dimensionality of the convolutional layers in the model
# heads: number of attention heads in the model
# num_experts: number of expert networks in the model
# num_experts_per_token: number of experts used for each token in the input data
model = Jamba(
    dim=512,
    depth=6,
    num_tokens=100,
    d_state=256,
    d_conv=128,
    heads=8,
    num_experts=8,
    num_experts_per_token=2,
)

# Perform a forward pass through the model with the input data
# This will return the model's predictions for each token in the input data
output = model(x)

# Print the model's predictions
print(output)
```

  5. See the error traceback:
```python
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

[<ipython-input-1-5e7c358db527>](https://localhost:8080/#) in <cell line: 5>()
      3 
      4 # Import the Jamba model from the jamba.model module
----> 5 from jamba.model import Jamba
      6 
      7 # Create a tensor of random integers between 0 and 100, with shape (1, 100)

23 frames

[/usr/local/lib/python3.10/dist-packages/jamba/__init__.py](https://localhost:8080/#) in <module>
----> 1 from jamba.model import JambaBlock, Jamba
      2 
      3 __all__ = ["JambaBlock", "Jamba"]

[/usr/local/lib/python3.10/dist-packages/jamba/model.py](https://localhost:8080/#) in <module>
      1 from torch import Tensor, nn
----> 2 from zeta import MambaBlock
      3 from zeta.nn import FeedForward
      4 from zeta import MultiQueryAttention
      5 from zeta.nn.modules.simple_rmsnorm import SimpleRMSNorm

[/usr/local/lib/python3.10/dist-packages/zeta/__init__.py](https://localhost:8080/#) in <module>
     26 logger.addFilter(f)
     27 
---> 28 from zeta.nn import *
     29 from zeta.models import *
     30 from zeta.utils import *

[/usr/local/lib/python3.10/dist-packages/zeta/nn/__init__.py](https://localhost:8080/#) in <module>
----> 1 from zeta.nn.attention import *
      2 from zeta.nn.embeddings import *
      3 from zeta.nn.modules import *
      4 from zeta.nn.biases import *

[/usr/local/lib/python3.10/dist-packages/zeta/nn/attention/__init__.py](https://localhost:8080/#) in <module>
     12 # from zeta.nn.attention.mgqa import MGQA
     13 # from zeta.nn.attention.spatial_linear_attention import SpatialLinearAttention
---> 14 from zeta.nn.attention.mixture_attention import (
     15     MixtureOfAttention,
     16     MixtureOfAutoregressiveAttention,

[/usr/local/lib/python3.10/dist-packages/zeta/nn/attention/mixture_attention.py](https://localhost:8080/#) in <module>
      6 from typing import Tuple, Optional
      7 from einops import rearrange, repeat, reduce
----> 8 from zeta.models.vit import exists
      9 from zeta.structs.transformer import RMSNorm, apply_rotary_pos_emb
     10 

[/usr/local/lib/python3.10/dist-packages/zeta/models/__init__.py](https://localhost:8080/#) in <module>
      1 # Copyright (c) 2022 Agora
      2 # Licensed under The MIT License [see LICENSE for details]
----> 3 from zeta.models.andromeda import Andromeda
      4 from zeta.models.base import BaseModel
      5 from zeta.models.gpt4 import GPT4, GPT4MultiModal

[/usr/local/lib/python3.10/dist-packages/zeta/models/andromeda.py](https://localhost:8080/#) in <module>
      2 from torch.nn import Module
      3 
----> 4 from zeta.structs.auto_regressive_wrapper import AutoregressiveWrapper
      5 from zeta.structs.transformer import (
      6     Decoder,

[/usr/local/lib/python3.10/dist-packages/zeta/structs/__init__.py](https://localhost:8080/#) in <module>
      2 from zeta.structs.encoder_decoder import EncoderDecoder
      3 from zeta.structs.hierarchical_transformer import HierarchicalTransformer
----> 4 from zeta.structs.local_transformer import LocalTransformer
      5 from zeta.structs.parallel_transformer import ParallelTransformerBlock
      6 from zeta.structs.transformer import (

[/usr/local/lib/python3.10/dist-packages/zeta/structs/local_transformer.py](https://localhost:8080/#) in <module>
      6 from zeta.nn.attention.local_attention_mha import LocalMHA
      7 from zeta.nn.biases.dynamic_position_bias import DynamicPositionBias
----> 8 from zeta.nn.modules import feedforward_network
      9 from zeta.utils.main import eval_decorator, exists, top_k
     10 

[/usr/local/lib/python3.10/dist-packages/zeta/nn/modules/__init__.py](https://localhost:8080/#) in <module>
     45 from zeta.nn.modules.s4 import s4d_kernel
     46 from zeta.nn.modules.h3 import H3Layer
---> 47 from zeta.nn.modules.mlp_mixer import MLPMixer
     48 from zeta.nn.modules.leaky_relu import LeakyRELU
     49 from zeta.nn.modules.adaptive_layernorm import AdaptiveLayerNorm

[/usr/local/lib/python3.10/dist-packages/zeta/nn/modules/mlp_mixer.py](https://localhost:8080/#) in <module>
    143     1, 512, 32, 32
    144 )  # Batch size of 1, 512 channels, 32x32 image
--> 145 output = mlp_mixer(example_input)
    146 print(
    147     output.shape

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

[/usr/local/lib/python3.10/dist-packages/zeta/nn/modules/mlp_mixer.py](https://localhost:8080/#) in forward(self, x)
    123         x = rearrange(x, "n c h w -> n (h w) c")
    124         for mixer_block in self.mixer_blocks:
--> 125             x = mixer_block(x)
    126         x = self.pred_head_layernorm(x)
    127         x = x.mean(dim=1)

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

[/usr/local/lib/python3.10/dist-packages/zeta/nn/modules/mlp_mixer.py](https://localhost:8080/#) in forward(self, x)
     61         y = self.norm1(x)
     62         y = rearrange(y, "n c t -> n t c")
---> 63         y = self.tokens_mlp(y)
     64         y = rearrange(y, "n t c -> n c t")
     65         x = x + y

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

[/usr/local/lib/python3.10/dist-packages/zeta/nn/modules/mlp_mixer.py](https://localhost:8080/#) in forward(self, x)
     28             torch.Tensor: _description_
     29         """
---> 30         y = self.dense1(x)
     31         y = F.gelu(y)
     32         return self.dense2(y)

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py](https://localhost:8080/#) in forward(self, input)
    114 
    115     def forward(self, input: Tensor) -> Tensor:
--> 116         return F.linear(input, self.weight, self.bias)
    117 
    118     def extra_repr(self) -> str:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x4 and 512x512)
```
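For what it's worth, the failure happens while the package is being imported, before the Jamba model is ever constructed, so a minimal reproduction (assuming the same package versions as in the Colab above) is just the import itself:

```python
# Minimal reproduction: the error is raised at import time, because
# module-level example code in zeta/nn/modules/mlp_mixer.py executes
# (and crashes) as a side effect of importing jamba.
from jamba.model import Jamba  # RuntimeError: mat1 and mat2 shapes cannot be multiplied
```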


github-actions[bot] commented 7 months ago

Hello there, thank you for opening an Issue ! 🙏🏻 The team was notified and they will get back to you asap.

kyegomez commented 7 months ago

@bruAristimunha please update your zetascale:

```bash
pip install -U zetascale
```
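For anyone following along, a quick way to confirm which versions actually end up installed before re-running the README example:

```bash
# Show the installed versions of the two packages involved
pip show zetascale jamba | grep -E "^(Name|Version)"
```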

bruAristimunha commented 7 months ago

Not working. The import still fails with the same traceback as above, ending in:

```python
RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x4 and 512x512)
```

yangyang2023a commented 7 months ago

Not working. I get the same traceback as above, ending in:

```python
RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x4 and 512x512)
```

I have encountered the same problem, and it also does not work when I install the jamba package in a local environment.

wyd0042 commented 7 months ago

I encountered the same problem: RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x4 and 512x512)

hitzhangyu commented 6 months ago

The same problem: RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x4 and 512x512)
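For context on the reports above: the traceback points at module-level example code in zeta/nn/modules/mlp_mixer.py, which builds an example input and calls the mixer as soon as zeta is imported, so the crash comes from that demo code rather than from the Jamba example in the README. A minimal, generic sketch of the usual way to keep demo code from running at import time (an illustration of the pattern only, not the actual zeta source):

```python
import torch
from torch import nn


class Demo(nn.Module):
    """Stand-in module, purely to illustrate the import-time pitfall."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(512, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


# Guard demo/example code so it only runs when the file is executed
# directly, never as a side effect of `import`.
if __name__ == "__main__":
    example_input = torch.randn(1, 32, 512)
    print(Demo()(example_input).shape)
```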

github-actions[bot] commented 4 months ago

Stale issue message