kyegomez / BitNet

Implementation of "BitNet: Scaling 1-bit Transformers for Large Language Models" in PyTorch

[BUG] BitFeedForward(post_act_ln=False) results in a TypeError #53

Closed Hiromasa-H closed 4 months ago

Hiromasa-H commented 4 months ago

Describe the bug

Instantiating BitFeedForward() with post_act_ln=False results in TypeError: 'NoneType' object is not callable, raised from inside torch.nn. (Full traceback shown in the "To Reproduce" section.)

This is due to the way post_act_ln is handled in bitnet/bit_ffn.py, where a None entry is placed into the nn.Sequential chain when post_act_ln=False:

```python
self.ff = nn.Sequential(
    project_in,
    nn.LayerNorm(inner_dim) if post_act_ln else None,
    nn.Dropout(dropout),
    BitLinear(inner_dim, dim_out, bias=not no_bias, *args, **kwargs),
)
```

Which results in:

```
# output of print(self.ff)
Sequential(
  (0): Sequential(
    (0): BitLinear(in_features=512, out_features=2048, bias=True)
    (1): SiLU()
  )
  (1): None  # <- cause of the error
  (2): Dropout(p=0.1, inplace=False)
  (3): BitLinear(in_features=2048, out_features=512, bias=True)
)
```

This None then triggers the aforementioned TypeError during the forward pass of self.ff:

```python
def forward(self, x):
    """
    Forward pass of the BitFeedForward module.

    Args:
        x (torch.Tensor): The input tensor.

    Returns:
        torch.Tensor: The output tensor.
    """
    return self.ff(x)
```
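
Note that nn.Sequential stores the None at construction time (as the printout above shows, where it appears as (1): None) and only fails once the container calls each submodule during forward. A minimal, bitnet-independent sketch of the same failure:

```python
import torch
import torch.nn as nn

# nn.Sequential accepts None at construction (nn.Module.add_module
# permits None), so nothing fails until the forward pass, when the
# container tries to call each entry in turn.
seq = nn.Sequential(nn.Identity(), None, nn.Identity())
print(seq)               # shows "(1): None", mirroring the printout above
seq(torch.randn(1, 4))   # TypeError: 'NoneType' object is not callable
```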

To Reproduce

Steps to reproduce the behavior:

  1. Create an env with the following configuration:
     - CUDA 12.3
     - bitnet==0.2.5
     - torch==2.3.0
     - torchvision==0.18.0
  2. Run bit_ffn.py with post_act_ln set to False: ff = BitFeedForward(512, 512, 4, swish=True, post_act_ln=False, dropout=0.1) (a self-contained script is sketched after the traceback below.)
  3. See the following error:
    
```
Traceback (most recent call last):
  File "/home/user/code/contributions/BitNet/bit_ffn.py", line 17, in <module>
    y = ff(x)
  File "/home/user/code/bitvit/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/code/bitvit/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/code/contributions/BitNet/bitnet/bit_ffn.py", line 129, in forward
    return self.ff(x)
  File "/home/user/code/bitvit/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/code/bitvit/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/code/bitvit/env/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
TypeError: 'NoneType' object is not callable
```
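
For convenience, here is step 2 as a self-contained script (a sketch: it assumes BitFeedForward is importable from the package root and that an input of shape (batch, seq, dim) is acceptable):

```python
import torch
from bitnet import BitFeedForward  # assumption: the package root re-exports BitFeedForward

# Same constructor call as in step 2 above
ff = BitFeedForward(512, 512, 4, swish=True, post_act_ln=False, dropout=0.1)

x = torch.randn(2, 64, 512)  # arbitrary (batch, seq, dim=512) input
y = ff(x)                    # raises TypeError: 'NoneType' object is not callable
```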

Expected behavior

When instantiating BitFeedForward() with post_act_ln=False, an nn.Sequential chain without any None entries should be returned; nothing should be inserted in place of nn.LayerNorm(inner_dim).
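
With the fix, print(self.ff) should show the same chain minus the None entry, with the remaining modules renumbered by nn.Sequential:

```
Sequential(
  (0): Sequential(
    (0): BitLinear(in_features=512, out_features=2048, bias=True)
    (1): SiLU()
  )
  (1): Dropout(p=0.1, inplace=False)
  (2): BitLinear(in_features=2048, out_features=512, bias=True)
)
```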

Possible Solutions

I can think of two solutions to this issue. One is an if/else statement that builds a different nn.Sequential depending on the value of post_act_ln; the other is to build the full list of layers and filter out the None entries afterwards (a sketch of the latter follows the snippet below). I've made a pull request for the former:
```python
if post_act_ln:
    self.ff = nn.Sequential(
        project_in,
        nn.LayerNorm(inner_dim),
        nn.Dropout(dropout),
        BitLinear(inner_dim, dim_out, bias=not no_bias, *args, **kwargs),
    )
else:
    self.ff = nn.Sequential(
        project_in,
        nn.Dropout(dropout),
        BitLinear(inner_dim, dim_out, bias=not no_bias, *args, **kwargs),
    )
```
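
For reference, a minimal sketch of the second (filtering) approach, which keeps a single construction path; layers is a hypothetical local name, not part of the existing code:

```python
layers = [
    project_in,
    nn.LayerNorm(inner_dim) if post_act_ln else None,  # placeholder when disabled
    nn.Dropout(dropout),
    BitLinear(inner_dim, dim_out, bias=not no_bias, *args, **kwargs),
]
# Drop the None placeholder before handing the modules to nn.Sequential.
self.ff = nn.Sequential(*(layer for layer in layers if layer is not None))
```

Both variants produce an identical module chain when post_act_ln=False; the if/else in the PR is more explicit, while filtering would scale better if more optional layers were ever added.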


Hiromasa-H commented 4 months ago

Oops, I just realized this is a duplicate of https://github.com/kyegomez/BitNet/issues/43. I will leave the PR open though.

kyegomez commented 4 months ago

@Hiromasa-H just accepted the pull request

Hiromasa-H commented 4 months ago

@kyegomez Thank you!