dmlc / gluon-nlp

NLP made easy
https://nlp.gluon.ai/
Apache License 2.0
2.56k stars 538 forks source link

sliding window self-attention cell #1395

Open · ZiyueHuang opened this pull request 3 years ago

ZiyueHuang commented 3 years ago

Description

This PR adds an AttentionCell for sliding-window self-attention, including support for multi-headed dilation and the causal attention mode, as described in Longformer: The Long-Document Transformer.

cc @sxjscience @szhengac

Checklist

Essentials

Changes

Comments

cc @dmlc/gluon-nlp-team

ZiyueHuang commented 3 years ago

Waiting for https://github.com/apache/incubator-mxnet/pull/19387 to be merged.

github-actions[bot] commented 3 years ago

The documentation website for preview: http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR1395/sw_atten_cell/index.html

sxjscience commented 3 years ago

Is it possible for us to revise the interface to be similar to https://www.deepspeed.ai/tutorials/sparse-attention/?

github-actions[bot] commented 3 years ago

The documentation website for preview: http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR1395/sw_atten_cell/index.html

ZiyueHuang commented 3 years ago

benchmark script


import numpy as np
from numpy.testing import assert_allclose
import mxnet as mx
from gluonnlp.attention_cell import masked_softmax, MultiHeadAttentionCell, MultiHeadSlidingWindowAttentionCell
import time

def test_multi_head_sliding_window_dot_attention_cell(ctx=None, num_repeat=5, num_warmup=2):
    """Benchmark full masked self-attention against sliding-window self-attention.

    For each sequence length, time one recorded forward + backward pass of
    ``MultiHeadAttentionCell`` driven by an explicit banded mask, and of
    ``MultiHeadSlidingWindowAttentionCell`` with the same window size ``w``.
    Print the mean wall-clock time of each.

    Parameters
    ----------
    ctx : mx.Context, optional
        Device to run on. Defaults to ``mx.gpu(0)``.
    num_repeat : int
        Total number of timed runs per configuration.
    num_warmup : int
        Number of leading runs excluded from the average (warm-up runs).

    Raises
    ------
    ValueError
        If ``num_repeat <= num_warmup``, because no run would be averaged.
    """
    if ctx is None:
        ctx = mx.gpu(0)
    if num_repeat <= num_warmup:
        raise ValueError('num_repeat ({}) must exceed num_warmup ({})'.format(
            num_repeat, num_warmup))

    def gen_sliding_window_mask_full(batch_size, seq_length, w, symmetric, d):
        """Generate the sliding-window mask for the full attention matrix ( seq_len^2 ).

        Entry (i, j) is 1 when ``j - i`` is a multiple of the dilation ``d``
        and lies in ``[-w*d, w*d]`` (symmetric) or ``[-w*d, 0]`` (causal).
        Assumes ``d >= 1``.
        """
        idx = np.arange(seq_length)
        offset = idx[None, :] - idx[:, None]  # offset[i, j] == j - i
        upper = w * d if symmetric else 0
        band = (offset >= -w * d) & (offset <= upper) & (offset % d == 0)
        # astype copies, so the result is a writable, independent float array.
        return np.broadcast_to(band, (batch_size, seq_length, seq_length)).astype(np.float64)

    def gen_qkv(batch_size, seq_length, num_heads, num_head_units):
        """Create random float32 query/key/value arrays on ``ctx`` with gradients attached."""
        shape = (batch_size, seq_length, num_heads, num_head_units)
        tensors = []
        for _ in range(3):  # query, key, value, in that order
            arr = mx.np.array(np.random.normal(0, 1, shape), ctx=ctx, dtype=np.float32)
            arr.attach_grad()
            tensors.append(arr)
        return tensors

    def timed_step(cell, *inputs):
        """Run one recorded forward + backward pass of ``cell``; return elapsed seconds."""
        # waitall() flushes MXNet's async engine so only this step is timed.
        mx.npx.waitall()
        tic = time.perf_counter()
        with mx.autograd.record():
            out, _ = cell(*inputs)
            out.backward()
        mx.npx.waitall()
        return time.perf_counter() - tic

    def test_selfatten(batch_size, seq_length, num_heads, num_head_units, w, symmetric, d):
        """Time full self-attention restricted by an explicit sliding-window mask."""
        attn_cell = MultiHeadAttentionCell()
        query, key, value = gen_qkv(batch_size, seq_length, num_heads, num_head_units)
        mask = mx.np.array(gen_sliding_window_mask_full(batch_size, seq_length, w, symmetric, d),
                           ctx=ctx, dtype=np.float32)
        return timed_step(attn_cell, query, key, value, mask)

    def test_swatten(batch_size, seq_length, num_heads, num_head_units, w, symmetric, d):
        """Time the dedicated sliding-window attention cell."""
        sw_attn_cell = MultiHeadSlidingWindowAttentionCell(w, symmetric)
        query, key, value = gen_qkv(batch_size, seq_length, num_heads, num_head_units)
        dilation = mx.np.array(np.full((num_heads,), d), ctx=ctx, dtype=np.int32)
        valid_length = mx.np.array(np.full((batch_size,), seq_length), ctx=ctx, dtype=np.int32)
        return timed_step(sw_attn_cell, query, key, value, dilation, valid_length)

    def mean_time(bench, *args):
        """Call ``bench(*args)`` ``num_repeat`` times; average all runs after the warm-up."""
        durations = [bench(*args) for _ in range(num_repeat)]
        return sum(durations[num_warmup:]) / (num_repeat - num_warmup)

    for seq_length in [512, 1024, 2048, 4096]:
        w = seq_length // 8
        args = (1, seq_length, 12, 64, w, True, 1)

        dur = mean_time(test_selfatten, *args)
        print('seq_length={}, w={}, time={:.3f}'.format(seq_length, w, dur))

        dur = mean_time(test_swatten, *args)
        print('sliding-window-attention seq_length={}, w={}, time={:.3f}'.format(seq_length, w, dur))
# Run the benchmark only when executed as a script. Importing the module,
# e.g. during pytest collection, no longer triggers the benchmark.
if __name__ == "__main__":
    test_multi_head_sliding_window_dot_attention_cell()
sxjscience commented 3 years ago

Is there any update on this PR?

szhengac commented 3 years ago

@sxjscience it seems the error `AttributeError: module 'mxnet.ndarray.numpy_extension' has no attribute 'sldwin_atten_score'` is caused by the installed MXNet version not being the latest.

sxjscience commented 3 years ago

Yes, we can merge master into this branch, which will retrigger the tests.

sxjscience commented 3 years ago

Is there an update on this? @ZiyueHuang, would you have time to rebase the code?