## Describe the bug
Instantiating `BitFeedForward()` with `post_act_ln=False` results in `TypeError: 'NoneType' object is not callable`, raised inside the `torch.nn` module. (The full traceback is shown in the “To Reproduce” section.)
This is caused by the way `post_act_ln` is handled in `bitnet/bit_ffn.py`: when `post_act_ln=False`, a `None` object is placed into the `nn.Sequential` chain:
# output of print(self.ff)
Sequential(
  (0): Sequential(
    (0): BitLinear(in_features=512, out_features=2048, bias=True)
    (1): SiLU()
  )
  (1): None  # <- cause of error
  (2): Dropout(p=0.1, inplace=False)
  (3): BitLinear(in_features=2048, out_features=512, bias=True)
)
This then triggers the aforementioned `TypeError` during the forward pass, when `self.ff` is called:
def forward(self, x):
    """
    Forward pass of the BitFeedForward module.
    Args:
        x (torch.Tensor): The input tensor.
    Returns:
        torch.Tensor: The output tensor.
    """
    return self.ff(x)
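The failure can also be reproduced with PyTorch alone. The following is a minimal sketch rather than the code from `bit_ffn.py`: it assumes the optional layer is selected with a conditional expression that evaluates to `None`, and it uses `nn.Linear` as a stand-in for `BitLinear` so that it runs without `bitnet` installed.

```python
import torch
from torch import nn

post_act_ln = False
inner_dim = 2048

# Assumption: the optional LayerNorm is chosen inline, so the chain receives
# None when post_act_ln is False. nn.Linear stands in for BitLinear.
ff = nn.Sequential(
    nn.Sequential(nn.Linear(512, inner_dim), nn.SiLU()),
    nn.LayerNorm(inner_dim) if post_act_ln else None,
    nn.Dropout(0.1),
    nn.Linear(inner_dim, 512),
)

x = torch.randn(1, 16, 512)

# Construction succeeds; the error only appears once the chain is called.
try:
    ff(x)
    error = None
except TypeError as exc:
    error = str(exc)

print(error)  # 'NoneType' object is not callable
```

`nn.Sequential` accepts `None` at construction time, which is why the problem only surfaces during the forward pass.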
## To Reproduce
Steps to reproduce the behavior:
1. Create an env with the following configuration:
   - CUDA 12.3
   - `bitnet==0.2.5`
   - `torch==2.3.0`
   - `torchvision==0.18.0`
2. Run `bit_ffn.py` with `post_act_ln` set to `False`:
   `ff = BitFeedForward(512, 512, 4, swish=True, post_act_ln=False, dropout=0.1)`
3. See the following error:
Traceback (most recent call last):
  File "/home/user/code/contributions/BitNet/bit_ffn.py", line 17, in <module>
    y = ff(x)
  File "/home/user/code/bitvit/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/code/bitvit/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/code/contributions/BitNet/bitnet/bit_ffn.py", line 129, in forward
    return self.ff(x)
  File "/home/user/code/bitvit/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/code/bitvit/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/code/bitvit/env/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
TypeError: 'NoneType' object is not callable
## Expected behavior
When instantiating `BitFeedForward()` with `post_act_ln=False`, an `nn.Sequential` chain without any `None` objects should be returned. (Nothing should be added to the `nn.Sequential` chain in place of `nn.LayerNorm(inner_dim)`.)
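One way to check this expectation is to walk the module tree and report any slot that holds `None`. The helper below is a hypothetical sketch (`find_none_slots` is not part of `bitnet` or PyTorch), and `nn.Linear` again stands in for `BitLinear`. It reads the `_modules` dictionary directly because `nn.Module.children()` skips `None` entries.

```python
from torch import nn


def find_none_slots(module, prefix=""):
    """Return the dotted names of all submodule slots that hold None."""
    missing = []
    for name, child in module._modules.items():
        path = f"{prefix}.{name}" if prefix else name
        if child is None:
            missing.append(path)
        else:
            # Recurse so that nested chains are checked as well.
            missing.extend(find_none_slots(child, path))
    return missing


# Chain shaped like the reported one, with None in slot 1.
broken = nn.Sequential(
    nn.Sequential(nn.Linear(512, 2048), nn.SiLU()),
    None,
    nn.Dropout(0.1),
    nn.Linear(2048, 512),
)

# Chain as expected when post_act_ln=False: nothing in place of LayerNorm.
fixed = nn.Sequential(
    nn.Sequential(nn.Linear(512, 2048), nn.SiLU()),
    nn.Dropout(0.1),
    nn.Linear(2048, 512),
)

print(find_none_slots(broken))  # ['1']
print(find_none_slots(fixed))   # []
```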
## Possible Solutions
I can think of two solutions to address this issue. One is to use an if/else statement to create different sequences depending on the value of `post_act_ln`. The other is to filter out `None` objects from the sequence after the existing code snippet. I’ve made a pull request for the former solution.
```python
if post_act_ln:
    self.ff = nn.Sequential(
        project_in,
        nn.LayerNorm(inner_dim),
        nn.Dropout(dropout),
        BitLinear(inner_dim, dim_out, bias=not no_bias, *args, **kwargs),
    )
else:
    self.ff = nn.Sequential(
        project_in,
        nn.Dropout(dropout),
        BitLinear(inner_dim, dim_out, bias=not no_bias, *args, **kwargs),
    )
```
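For comparison, the second option (filtering out `None` objects) could look roughly like the sketch below. It is not the code from the pull request: `build_ff` is a hypothetical helper, `nn.Linear` stands in for `BitLinear`, and the `project_in` block is assumed to be a linear layer followed by `SiLU`, as in the printed output above.

```python
import torch
from torch import nn


def build_ff(dim, inner_dim, dim_out, dropout=0.0, post_act_ln=False):
    """Hypothetical helper: build the chain, then drop None placeholders."""
    project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.SiLU())
    layers = [
        project_in,
        nn.LayerNorm(inner_dim) if post_act_ln else None,
        nn.Dropout(dropout),
        nn.Linear(inner_dim, dim_out),
    ]
    # Keep only real modules so every entry in the chain is callable.
    return nn.Sequential(*[layer for layer in layers if layer is not None])


ff = build_ff(512, 2048, 512, dropout=0.1, post_act_ln=False)
y = ff(torch.randn(2, 4, 512))
print(len(ff), tuple(y.shape))  # 3 (2, 4, 512)
```

Filtering keeps a single construction path, while the if/else version makes the two layer layouts explicit. Both produce a chain without `None` entries.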