huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Add support for ONNX-TensorRT conversion for GPT-J6B (and possible bug in rotary embedding) #15640

Closed (tomerip closed this issue 2 years ago)

tomerip commented 2 years ago

Who can help

@patil-suraj

Information

Model I am using: GPT-J

Description

I opened this issue for two reasons:

  1. This is not strictly a bug report; rather, it describes a change that enables exporting this model to ONNX and then parsing it with the current TensorRT ONNX parser.
  2. Possible implementation bug in GPT-J.

Details

  1. When exporting GPT-J to ONNX using the latest version (v4.16.2), one of the exported ops is SplitToSequence (along with other Sequence* ops), which is currently not supported by the TensorRT ONNX parser. This is due entirely to a single line of code that uses torch.repeat_interleave. (relevant line)

    sin, cos = map(lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3), sincos)

    By replacing lambda t with this:

    lambda t: t.view(-1, 1).repeat(1, 2).view(seq_len, -1)[None, offset : x.shape[1] + offset, None, :]

    we get exactly the same output tensors, but the ONNX export no longer includes any Sequence* ops and TensorRT can parse it successfully. The suggested expression is even faster, although that is probably not critical for a model this large (benchmarked on CPU only):

    original: 106 µs ± 20.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    suggested: 32.4 µs ± 6.55 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
  2. I was following the EleutherAI implementation of rotary positional embeddings and I'm trying to understand whether this is a bug or I'm simply missing something (I'd love an explanation if you can spare the time). In EleutherAI's code, this function is implemented using torch.cat instead of torch.repeat_interleave, as can be seen here.

If I'm not missing something, the EleutherAI version transforms a tensor from

[[1,2,3],
 [4,5,6]]

to

[[1,2,3,1,2,3],
 [4,5,6,4,5,6]]

while the HF version (using repeat_interleave) transforms

[[1,2,3],
 [4,5,6]]

to

[[1,1,2,2,3,3],
 [4,4,5,5,6,6]]

Can anyone confirm that the current implementation is indeed correct? Otherwise, cat and repeat_interleave are very different, and the rest of the implementation doesn't account for the difference.
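
For concreteness, the two behaviours can be reproduced with a few lines of PyTorch:

import torch

t = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# EleutherAI-style duplication: concatenate the tensor with itself along the last dim.
print(torch.cat((t, t), dim=-1))
# tensor([[1, 2, 3, 1, 2, 3],
#         [4, 5, 6, 4, 5, 6]])

# HF GPT-J-style duplication: repeat each element twice along the last dim.
print(t.repeat_interleave(2, dim=-1))
# tensor([[1, 1, 2, 2, 3, 3],
#         [4, 4, 5, 5, 6, 6]])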

LysandreJik commented 2 years ago

Maybe also of interest to @lewtun @michaelbenayoun

lewtun commented 2 years ago

Hey @tomerip thank you for this very detailed and informative feedback!

Regarding point (1), I agree that it would be nice to have a workaround for the SplitToSequence op that is absent from the TensorRT parser.

To proceed, I think we'd need to try the following:

  1. Validate that the proposed change doesn't have a negative impact on fine-tuning / vanilla inference (i.e. on model metrics like accuracy etc). GPT-J is a popular model, so we need to be sure that we don't introduce a subtle error.
  2. Implement an ONNX config for this architecture (see here for a guide)
  3. Validate that the proposed change to the modeling file works for the supported features (i.e. tasks) defined by step (2)
  4. [Nice to have] Validate that the proposed change also works for ONNX Runtime. They support the SplitToSequence op, so my guess is that the proposed change would also work there (although I haven't checked; a rough parity check is sketched below)

I can help out with steps 2-4, but would like to first know what the impact of step (1) is. Regarding your point (2), I'll defer to @patil-suraj who is the expert on the GPT-J models 😃
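
For step (4), a rough parity check could look like the sketch below (untested; the exported file name and the output name are placeholders, and the export is assumed to take input_ids and attention_mask):

import numpy as np
import onnxruntime as ort
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/gpt-j-6B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).eval()

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    torch_logits = model(**inputs).logits.numpy()

# "gptj.onnx" and the "logits" output name are placeholders for whatever the export produces.
session = ort.InferenceSession("gptj.onnx")
ort_logits = session.run(["logits"], {k: v.numpy() for k, v in inputs.items()})[0]

# Loose tolerance: graph optimizations mean the two outputs won't match bit-for-bit.
print(np.allclose(torch_logits, ort_logits, atol=1e-4))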

tomerip commented 2 years ago

Hi @lewtun, thanks for replying. I agree that the first step should be making sure the change keeps everything the same. Since the new lambda_t function returns exactly the same outputs, and the underlying ops are just repeats, I don't think the gradients will be different (a quick gradient check is sketched after the script below). In terms of performance, as shown in my original post it's even slightly faster (although probably negligible anyway).

The short script below shows that both versions return the same outputs:

import torch
import random

random.seed(42)

def fixed_pos_embedding(x, seq_dim=1, seq_len=None):  # no changes in this function
    dim = x.shape[-1]
    if seq_len is None:
        seq_len = x.shape[seq_dim]
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float()
    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)

def test_lambda_t(batch_size, num_heads, seq_len, head_features, num_blocks, offset):

    x1 = torch.randn(batch_size, num_heads, seq_len, head_features)  # (batch, head, seq_length, head_features)
    x2 = torch.randn(batch_size, num_blocks, num_heads, seq_len, head_features)  # (batch, blocks, head, block_length, head_features)

    for x in [x1, x2]:

        sincos = fixed_pos_embedding(x, 1, seq_len=seq_len)

        lambda_old = lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3)
        lambda_new = lambda t: t.view(-1, 1).repeat(1, 2).view(seq_len, -1)[None, offset : x.shape[1] + offset, None, :]

        sin_old, cos_old = map(lambda_old, sincos)
        sin_new, cos_new = map(lambda_new, sincos)

        if not (torch.equal(sin_old, sin_new) and torch.equal(cos_old, cos_new)):
            return False

    return True

def test_lambda_t_with_params(num_iter=10):
    for _ in range(num_iter):
        params = {
            'batch_size': random.randint(1, 8),
            'num_heads': random.randint(4, 16),
            'seq_len': random.randint(32, 2048),
            'head_features': random.randint(64, 512),
            'num_blocks': random.randint(4, 12),
            'offset': random.randint(0, 4),
        }
        if not test_lambda_t(**params):
            return False
    return True

print(test_lambda_t_with_params())  # True for all sampled shapes
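
Along the same lines, here is a quick sketch (reusing the imports above) of a gradient check for the two formulations; x_len stands in for x.shape[1]:

def test_lambda_t_grads(seq_len=128, dim=32, x_len=16, offset=2):
    sinusoid = torch.randn(seq_len, dim)

    t_old = sinusoid.clone().requires_grad_(True)
    t_new = sinusoid.clone().requires_grad_(True)

    out_old = t_old[None, offset : x_len + offset, None, :].repeat_interleave(2, 3)
    out_new = t_new.view(-1, 1).repeat(1, 2).view(seq_len, -1)[None, offset : x_len + offset, None, :]

    # Backpropagate the same upstream gradient through both versions.
    grad = torch.randn_like(out_old)
    out_old.backward(grad)
    out_new.backward(grad)
    return torch.allclose(t_old.grad, t_new.grad)

print(test_lambda_t_grads())  # True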

Regarding steps 2-3: since I hit other issues with TensorRT and GPT-J, I have moved to DeepSpeed for now (which currently has some GPT-J problems of its own :D), so I probably won't pursue the ONNX path soon, although I intend to in the future.

Still, if @patil-suraj could comment on point (2), that would be great!

patil-suraj commented 2 years ago

Will look into this and let you know tomorrow.

lewtun commented 2 years ago

Hey @tomerip thanks a lot for the detailed explanation and for showing that the suggested change produces identical outputs!

If @patil-suraj agrees, would you like to open a PR to implement this? I can then take care of the ONNX export :)

tomerip commented 2 years ago

Hi @lewtun, sounds great. I'd be happy to open a PR.

patil-suraj commented 2 years ago

Hey @tomerip ! Thanks a lot for this, feel free to open a PR!

"Can anyone confirm the current implementation is indeed correct?"

It is correct. In GPTNeoX it is different because it uses the [seq, batch, heads, hdim] format for the qkv tensors, whereas in transformers it is [batch, seq, heads, hdim].
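
To make that concrete, here is a simplified 1-D sketch (the rotation helpers are paraphrased, not the actual modeling code): each duplication scheme is paired with its own rotation helper, and the two conventions agree up to a fixed permutation within the head dimension.

import torch

d = 8
torch.manual_seed(0)
x = torch.randn(d)
theta = torch.randn(d // 2)  # one rotation angle per 2-D pair
sin, cos = theta.sin(), theta.cos()

# GPT-J style: adjacent elements (0, 1), (2, 3), ... form the rotated pairs,
# so sin/cos are duplicated element-wise (repeat_interleave).
def rotate_every_two(v):  # 1-D paraphrase
    return torch.stack((-v[1::2], v[::2]), dim=-1).flatten()

sin_i, cos_i = sin.repeat_interleave(2), cos.repeat_interleave(2)
out_interleaved = x * cos_i + rotate_every_two(x) * sin_i

# GPT-NeoX style: elements (0, d/2), (1, d/2 + 1), ... form the pairs,
# so sin/cos are duplicated by concatenation (torch.cat).
def rotate_half(v):  # 1-D paraphrase
    return torch.cat((-v[d // 2:], v[:d // 2]))

sin_h, cos_h = torch.cat((sin, sin)), torch.cat((cos, cos))
perm = torch.cat((torch.arange(0, d, 2), torch.arange(1, d, 2)))  # line the pairs up
out_half = x[perm] * cos_h + rotate_half(x[perm]) * sin_h

# Same rotation, just a different ordering within the head dimension.
print(torch.allclose(out_interleaved[perm], out_half))  # True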

tomerip commented 2 years ago

Hey @patil-suraj, should we keep this issue open so @lewtun can take care of the ONNX export (i.e. steps 2-4 in his comment above)?

github-actions[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

lewtun commented 2 years ago

The GPT-J export was added by a community member in https://github.com/huggingface/transformers/pull/16274