kyegomez / BitNet

Implementation of "BitNet: Scaling 1-bit Transformers for Large Language Models" in PyTorch
https://discord.gg/qUtxnK2NMf
MIT License
1.69k stars, 155 forks

[BUG] Tensor size mismatch from train.py #19

Closed: richardburleigh closed this issue 8 months ago

richardburleigh commented 8 months ago

Thank you for sharing this incredible work!

I speculate that it's an issue of library versions, but I'm receiving the following error when attempting to run unmodified train.py: RuntimeError: The size of tensor a (1024) must match the size of tensor b (512) at non-singleton dimension 1

Changing the default SEQ_LEN from 1024 to 512 gives the following: RuntimeError: The size of tensor a (513) must match the size of tensor b (512) at non-singleton dimension 1

With a sequence length of 511, the error becomes: RuntimeError: The size of tensor a (511) must match the size of tensor b (512) at non-singleton dimension 1

Full error log:

Traceback (most recent call last):
  File "Data/Development/BitNet/train.py", line 86, in <module>
    loss = model(next(train_loader))
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "Data/Development/BitNet/bitnet/at.py", line 82, in forward
    logits = self.net(x_inp, **kwargs)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "Data/Development/BitNet/bitnet/transformer.py", line 52, in forward
    return self.to_logits(x)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File ".local/lib/python3.10/site-packages/zeta/nn/modules/rms_norm.py", line 35, in forward
    return normed * self.scale * self.gamma
RuntimeError: The size of tensor a (1024) must match the size of tensor b (512) at non-singleton dimension 1

Any help would be appreciated!
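The last frame of the traceback fails inside RMSNorm, on the line normed * self.scale * self.gamma, and the maintainer later confirms the bug was in RMSNorm. The sketch below is a minimal, hypothetical RMSNorm (not the zeta or BitNet code) whose gain gamma has shape (dim,), so it broadcasts over any batch size and sequence length. It also reproduces the exact error message by multiplying with a deliberately wrong, hypothetical gain of shape (512, dim). That is one way to get this message; whether it matches the actual bug is an assumption, not something the thread confirms.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Minimal RMSNorm sketch (hypothetical, not the zeta implementation).

    gamma has shape (dim,), so it broadcasts over (batch, seq_len, dim)
    for any batch size and sequence length.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.scale = dim ** 0.5                     # undoes the 1/sqrt(dim) from L2 normalisation
        self.gamma = nn.Parameter(torch.ones(dim))  # one learnable gain per feature

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = F.normalize(x, dim=-1)             # x / ||x||_2 along the feature axis
        return normed * self.scale * self.gamma


dim = 512  # assumed model width, for illustration only
norm = RMSNorm(dim)
for seq_len in (1024, 513, 512, 511):
    x = torch.randn(1, seq_len, dim)
    assert norm(x).shape == x.shape  # a per-feature gain works for every length

# Hypothetical broken gain with an extra leading axis of size 512:
bad_gamma = torch.ones(512, dim)
try:
    F.normalize(torch.randn(1, 1024, dim), dim=-1) * bad_gamma
    mismatch_message = ""
except RuntimeError as err:
    mismatch_message = str(err)
print(mismatch_message)
# The size of tensor a (1024) must match the size of tensor b (512) at non-singleton dimension 1
```

Broadcasting aligns shapes from the last axis: (1, 1024, 512) against (512, 512) agrees on the feature axis but not on axis 1, which is why the message names dimension 1 and why changing SEQ_LEN alone only moves the error around.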

nathanielhudson commented 8 months ago

FWIW, I'm seeing the same error over here, using PyTorch 2.2.1.

nathanielhudson commented 8 months ago

@richardburleigh, actually I'm seeing something slightly different from what you describe. With SEQ_LEN = 1024, the error is on line 86 as noted. With SEQ_LEN = 512, the error is actually on line 106: sample = model.generate(inp[None, ...], GENERATE_LENGTH)

I'm at a bit of a loss here though, as inp[None, ...] definitely has the shape [1, 512], and GENERATE_LENGTH is definitely 512... so I'm not sure why the error is RuntimeError: The size of tensor a (513) must match the size of tensor b (512) at non-singleton dimension 1

Changing inp to have the shape [1, 511] does not help, as the error becomes RuntimeError: The size of tensor a (511) must match the size of tensor b (512) at non-singleton dimension 1

nathanielhudson commented 8 months ago

Okay, I've got it (I think). In addition to setting SEQ_LEN = 512, the line out = torch.cat((out, sample), dim=-1) in at.py needs to be changed to out = torch.cat((out[:, :-1], sample), dim=-1).
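The reports above fit a model that accepts exactly 512 tokens: 1024 and 511 fail on the first forward pass, and 513 appears because each generation step appends one token to a 512-token prompt. The workaround keeps the tensor at 512 tokens by discarding its last token before each append. The sketch below shows a common alternative: feed the model only the last MAX_SEQ_LEN tokens while the full output keeps growing. Everything here is hypothetical and illustrative, not BitNet's at.py: MAX_SEQ_LEN, NUM_TOKENS, fixed_length_model (a dummy that predicts "previous token + 1") and generate are names introduced for this example.

```python
import torch
import torch.nn.functional as F

MAX_SEQ_LEN = 512  # assumed fixed context length, mirroring SEQ_LEN = 512
NUM_TOKENS = 256   # assumed vocabulary size; illustrative only


def fixed_length_model(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a model that, like the reports above,
    only accepts sequences of exactly MAX_SEQ_LEN tokens."""
    if x.shape[-1] != MAX_SEQ_LEN:
        raise RuntimeError(f"got length {x.shape[-1]}, expected {MAX_SEQ_LEN}")
    # Dummy logits: at every position, favour (current token + 1) mod NUM_TOKENS.
    return F.one_hot((x + 1) % NUM_TOKENS, NUM_TOKENS).float()


@torch.no_grad()
def generate(prompt: torch.Tensor, steps: int) -> torch.Tensor:
    out = prompt
    for _ in range(steps):
        window = out[:, -MAX_SEQ_LEN:]                # crop to the last MAX_SEQ_LEN tokens
        logits = fixed_length_model(window)[:, -1]    # logits for the final position
        sample = logits.argmax(dim=-1, keepdim=True)  # greedy pick, shape (batch, 1)
        out = torch.cat((out, sample), dim=-1)        # full sequence keeps growing
    return out[:, prompt.shape[-1]:]                  # only the newly generated tokens


prompt = torch.zeros(1, MAX_SEQ_LEN, dtype=torch.long)
generated = generate(prompt, 512)
print(generated.shape)  # torch.Size([1, 512])
```

Without the crop, the second step would pass 513 tokens and fail, as in the reports. One caveat, under the unconfirmed assumption that at.py's generate returns only the tokens appended after the prompt (a common pattern): with the out[:, :-1] workaround, the sequence never grows past the prompt length, so that slice would be empty. The crop above keeps every generated token. A model whose norm layers do not depend on sequence length would also accept prompts shorter than 512.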

MichelNivard commented 8 months ago

I got the same error. For me, editing the out = line lets the training loop run, but no generation is being output at all...

kyegomez commented 8 months ago

@nathanielhudson @MichelNivard Thanks for the issue. The error was in RMSNorm; it works now. Just git clone and/or git pull.

MichelNivard commented 8 months ago

I seem to have got it running thanks!

kyegomez commented 8 months ago

@