Jamie-Stirling / RetNet

An implementation of "Retentive Network: A Successor to Transformer for Large Language Models"
MIT License

Chunkwise retention giving different output #10

Closed by Jamie-Stirling 11 months ago

Jamie-Stirling commented 11 months ago

The implementation of the chunkwise retention paradigm on the chunkwise-real branch gives different outputs from the other two paradigms (parallel and recurrent).

It appears there may be a mistake in equation (7) of the paper on which the implementation was based. A pull request fixing this, so that the chunkwise output is consistent with the other two paradigms, would be greatly appreciated.
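For context, the three paradigms are meant to be numerically equivalent ways of computing the same retention output. Below is a minimal sketch of that equivalence for the parallel and recurrent forms, using assumed shapes and variable names rather than the repo's actual test code:

```python
import torch

# Minimal sketch (assumed shapes/names, not the repo's exact test code):
# the parallel and recurrent forms of retention should produce the same
# output for the same inputs, and the chunkwise form must match both.
batch, seq_len, d = 1, 8, 4
gamma = 0.9
torch.manual_seed(0)
Q = torch.randn(batch, seq_len, d)
K = torch.randn(batch, seq_len, d)
V = torch.randn(batch, seq_len, d)

# Parallel form: (Q K^T ⊙ D) V, with decay mask D[n, m] = gamma**(n - m)
# for n >= m and zero above the diagonal.
n = torch.arange(seq_len).view(-1, 1)
m = torch.arange(seq_len).view(1, -1)
D = (gamma ** (n - m).float()) * (n >= m)
Y_parallel = (Q @ K.transpose(-1, -2) * D) @ V

# Recurrent form: s_n = gamma * s_{n-1} + k_n^T v_n, then y_n = q_n s_n.
s = torch.zeros(batch, d, d)
ys = []
for t in range(seq_len):
    s = gamma * s + K[:, t:t+1].transpose(-1, -2) @ V[:, t:t+1]
    ys.append(Q[:, t:t+1] @ s)
Y_recurrent = torch.cat(ys, dim=1)

print(torch.allclose(Y_parallel, Y_recurrent, atol=1e-5))  # expected: True
```

The failing tests check exactly this kind of agreement, with the chunkwise implementation as the third party to the comparison.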

This can be reproduced by running `python src/tests.py`, which prints the following to stdout:

FFF
======================================================================
FAIL: test_retnet (__main__.TestRetNet)
verify that the three implementations of RetNet are identical
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jamie/Repos/RetNet/src/tests.py", line 137, in test_retnet
    self.assertTrue(torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)) # fails
AssertionError: False is not true

======================================================================
FAIL: test_multiscale (__main__.TestRetention)
verify that the three implementations of MultiScaleRetention are identical
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jamie/Repos/RetNet/src/tests.py", line 86, in test_multiscale
    self.assertTrue(torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)) # fails
AssertionError: False is not true

======================================================================
FAIL: test_simple (__main__.TestRetention)
verify that the three implementations of SimpleRetention are identical
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jamie/Repos/RetNet/src/tests.py", line 45, in test_simple
    assert torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5) # fails
AssertionError

----------------------------------------------------------------------
Ran 3 tests in 0.098s

FAILED (failures=3)
donglixp commented 11 months ago

@Jamie-Stirling https://github.com/microsoft/unilm/issues/1213

donglixp commented 11 months ago

You could also refer to https://github.com/microsoft/torchscale/commit/bf65397b26469ac9c24d83a9b779b285c1ec640b

Jamie-Stirling commented 11 months ago

@donglixp Thanks so much for your comment; it was critical to solving this issue.

There was also another term omitted from equation (7) in the paper but present in the torchscale implementation. Please see line 85 of retention.py:

r_i = (K.transpose(-1, -2) @ (V * D[-1].view(1, chunk_size, 1))) + (self.gamma ** chunk_size) * r_i_1

In particular:

D[-1].view(1, chunk_size, 1)
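For intuition: assuming the usual inner decay mask D with D[n, m] = gamma**(n - m) for n >= m, the last row D[-1] holds gamma**(chunk_size - 1 - j), so this factor decays each position's value by its distance from the end of the chunk before it is folded into the recurrent state. A minimal sketch (illustrative names, not the repo's code):

```python
import torch

# Sketch of the cross-chunk state update, assuming the decay mask
# D[n, m] = gamma**(n - m) for n >= m (zero above the diagonal).
gamma, chunk_size, d = 0.9, 4, 8
n = torch.arange(chunk_size).view(-1, 1)
m = torch.arange(chunk_size).view(1, -1)
D = (gamma ** (n - m).float()) * (n >= m)

# D[-1][j] == gamma**(chunk_size - 1 - j): positions early in the chunk
# are decayed more, since they sit further from the chunk boundary.
print(D[-1])  # [gamma**3, gamma**2, gamma**1, gamma**0]

K = torch.randn(1, chunk_size, d)
V = torch.randn(1, chunk_size, d)
r_prev = torch.zeros(1, d, d)

# The corrected update: decay V positionwise via D[-1] before accumulating,
# and decay the previous state by gamma**chunk_size.
r = (K.transpose(-1, -2) @ (V * D[-1].view(1, chunk_size, 1))
     + (gamma ** chunk_size) * r_prev)
```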
donglixp commented 11 months ago

> @donglixp Thanks so much for your comment; it was critical to solving this issue.
>
> There was also another term omitted from equation (7) in the paper but present in the torchscale implementation. Please see line 85 of retention.py:
>
> r_i = (K.transpose(-1, -2) @ (V * D[-1].view(1, chunk_size, 1))) + (self.gamma ** chunk_size) * r_i_1
>
> In particular:
>
> D[-1].view(1, chunk_size, 1)

Equation (7) of the latest arXiv paper (https://arxiv.org/pdf/2307.08621v4.pdf) fixed the issue.
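Written out, the state update in the code above corresponds to the following (a reconstruction from the code, not a verbatim transcription of the paper's equation (7)):

```latex
% Cross-chunk state update implied by the code above: with chunk size B
% and decay gamma, the state R_i accumulates the current chunk's K^T V
% with positionwise decay gamma^(B-1-j), plus the previous state decayed
% by gamma^B.
\[
R_i = K_{[i]}^{\top} \left( V_{[i]} \odot \zeta \right) + \gamma^{B} R_{i-1},
\qquad \zeta_{j} = \gamma^{\,B-1-j}, \quad j = 0, \dots, B-1,
\]
% where \zeta is broadcast across the feature dimension of V_{[i]},
% i.e. row j of V_{[i]} is scaled by \gamma^{B-1-j}; this is the
% D[-1] factor in the code.
```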