amazon-science / chronos-forecasting

Chronos: Pretrained (Language) Models for Probabilistic Time Series Forecasting
https://arxiv.org/abs/2403.07815
Apache License 2.0

Use efficient implementation of attention #33

Open · abdulfatir opened this issue 7 months ago

abdulfatir commented 7 months ago

I'm wondering what the best way is to use efficient implementations of attention. PyTorch provides the experimental torch.nn.functional.scaled_dot_product_attention (SDPA), which supports three different implementations: flash attention, memory-efficient attention, and a reference math implementation. Unfortunately, we cannot use flash attention because it doesn't support arbitrary attention masks yet, and Chronos needs them: T5's relative position bias and the padding mask are both passed to attention as an additive float mask. It's not clear when attention mask support will be added to flash attention (see https://github.com/Dao-AILab/flash-attention/issues/840). Meanwhile, when a mask is provided, SDPA falls back to another efficient implementation (memory-efficient attention).

I monkey-patched T5Attention.forward in transformers to call SDPA instead; here are the results (script below).
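The script below switches between the two versions by commenting out the patch line. A hypothetical alternative, sketched here, is a small context manager that swaps the method in and restores it on exit, so both variants can be benchmarked in one process. It is shown on a stand-in class; with transformers the call would presumably be patched_method(T5Attention, "forward", sdpa_forward).

```python
from contextlib import contextmanager


@contextmanager
def patched_method(cls, name, replacement):
    """Hypothetical helper: temporarily replace cls.<name>, restoring it on exit.

    Assumes the method is defined on cls itself (not only inherited), so that
    putting the original back with setattr is an exact restore.
    """
    original = getattr(cls, name)
    setattr(cls, name, replacement)
    try:
        yield
    finally:
        setattr(cls, name, original)


# Stand-in for T5Attention, for illustration only
class Attention:
    def forward(self):
        return "eager"


def fake_sdpa_forward(self):
    return "sdpa"


attn = Attention()  # existing instances see the patch, since lookup goes through the class
with patched_method(Attention, "forward", fake_sdpa_forward):
    print(attn.forward())  # prints "sdpa"
print(attn.forward())  # prints "eager"
```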

Results

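TL;DR: SDPA is clearly faster than the transformers implementation we're currently using, even without flash attention: roughly 1.1x to 1.2x in float32 on both GPUs (only about 1.02x for chronos-t5-tiny on A100) and 1.1x to 1.9x in bfloat16 on A100, with no consistent change in MASE or WQL. All runs forecast the m4_hourly test set; inference_time is wall-clock seconds for the whole set.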

V100 (float32)

Note: V100 doesn't support bfloat16, so there are no bf16 results for it: SDPA won't work in bf16 on this GPU because the custom kernels don't exist.
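To avoid requesting bf16 on hardware like the V100, the dtype could be chosen from the GPU's compute capability. This is a hypothetical helper, not part of Chronos; the 8.0 threshold corresponds to Ampere (A100), while V100 is 7.0.

```python
import torch


def pick_dtype(device: int = 0) -> torch.dtype:
    """Hypothetical helper: bfloat16 on GPUs with native support, else float32."""
    # Native bfloat16 support starts at compute capability 8.0 (Ampere, e.g. A100);
    # V100 is 7.0, so it falls back to float32, as do CPU-only machines.
    if torch.cuda.is_available() and torch.cuda.get_device_capability(device) >= (8, 0):
        return torch.bfloat16
    return torch.float32


print(pick_dtype())
```

The result could replace the hard-coded torch_dtype=torch.float32 passed to ChronosPipeline.from_pretrained. Checking the compute capability directly, rather than torch.cuda.is_bf16_supported(), is a deliberate choice, since some PyTorch versions may also count emulated bf16 as supported.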

Using transformers (current version):

                         MASE[0.5]  mean_weighted_sum_quantile_loss  inference_time
model                                                                              
amazon/chronos-t5-base    0.907140                         0.029036      108.118132
amazon/chronos-t5-large   0.950026                         0.021954      313.266375
amazon/chronos-t5-mini    0.874078                         0.024838       21.096206
amazon/chronos-t5-small   0.858876                         0.026758       31.885802
amazon/chronos-t5-tiny    1.001285                         0.029381       11.453301

Using SDPA:

                         MASE[0.5]  mean_weighted_sum_quantile_loss  inference_time
model                                                                              
amazon/chronos-t5-base    0.906459                         0.028953       92.497118
amazon/chronos-t5-large   0.943967                         0.021321      278.541993
amazon/chronos-t5-mini    0.867597                         0.026133       17.471496
amazon/chronos-t5-small   0.861364                         0.026423       26.355608
amazon/chronos-t5-tiny    0.983139                         0.028681        9.756106

A100 (float32)

Using transformers (current version):

                         MASE[0.5]  mean_weighted_sum_quantile_loss  inference_time
model                                                                              
amazon/chronos-t5-base    0.907520                         0.029853       76.029036
amazon/chronos-t5-large   0.938383                         0.021884      217.341671
amazon/chronos-t5-mini    0.875678                         0.025812       13.985228
amazon/chronos-t5-small   0.860030                         0.025327       20.903673
amazon/chronos-t5-tiny    0.984638                         0.029327        8.722677

Using SDPA:

                         MASE[0.5]  mean_weighted_sum_quantile_loss  inference_time
model                                                                              
amazon/chronos-t5-base    0.901114                         0.029077       63.078673
amazon/chronos-t5-large   0.944282                         0.022607      185.249409
amazon/chronos-t5-mini    0.870160                         0.026177       11.738740
amazon/chronos-t5-small   0.850184                         0.026167       18.250515
amazon/chronos-t5-tiny    0.975677                         0.029291        8.546939

A100 (bfloat16)

Using transformers (current version):

                         MASE[0.5]  mean_weighted_sum_quantile_loss  inference_time
model                                                                              
amazon/chronos-t5-base    0.903433                         0.026808       52.598027
amazon/chronos-t5-large   0.945507                         0.022141      149.007310
amazon/chronos-t5-mini    0.874791                         0.024425       10.292101
amazon/chronos-t5-small   0.871871                         0.027540       14.947764
amazon/chronos-t5-tiny    0.994311                         0.030779        7.021869

Using SDPA:

                         MASE[0.5]  mean_weighted_sum_quantile_loss  inference_time
model                                                                              
amazon/chronos-t5-base    0.902784                         0.029677       36.885420
amazon/chronos-t5-large   0.938067                         0.020137      134.648429
amazon/chronos-t5-mini    0.867450                         0.025005        5.402657
amazon/chronos-t5-small   0.861055                         0.027413        7.715756
amazon/chronos-t5-tiny    0.979267                         0.029882        5.227138

Script

import timeit

import numpy as np
import pandas as pd
import torch
from gluonts.dataset.repository import get_dataset
from gluonts.dataset.split import split
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model.forecast import SampleForecast
from torch.nn.functional import scaled_dot_product_attention as sdpa
from transformers.models.t5.modeling_t5 import T5Attention

from chronos import ChronosPipeline

def sdpa_forward(
    self,
    hidden_states,
    mask=None,
    key_value_states=None,
    position_bias=None,
    past_key_value=None,
    layer_head_mask=None,
    query_length=None,
    use_cache=False,
    output_attentions=False,
):
    """
    Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
    """
    # Input is (batch_size, seq_length, dim)
    # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
    # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
    batch_size, seq_length = hidden_states.shape[:2]

    real_seq_length = seq_length

    if past_key_value is not None:
        if len(past_key_value) != 2:
            raise ValueError(
                f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
            )
        real_seq_length += (
            past_key_value[0].shape[2] if query_length is None else query_length
        )

    key_length = (
        real_seq_length if key_value_states is None else key_value_states.shape[1]
    )

    def shape(states):
        """projection"""
        return states.view(
            batch_size, -1, self.n_heads, self.key_value_proj_dim
        ).transpose(1, 2)

    def unshape(states):
        """reshape"""
        return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

    def project(hidden_states, proj_layer, key_value_states, past_key_value):
        """projects hidden states correctly to key/query states"""
        if key_value_states is None:
            # self-attn
            # (batch_size, n_heads, seq_length, dim_per_head)
            hidden_states = shape(proj_layer(hidden_states))
        elif past_key_value is None:
            # cross-attn
            # (batch_size, n_heads, seq_length, dim_per_head)
            hidden_states = shape(proj_layer(key_value_states))

        if past_key_value is not None:
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, key_length, dim_per_head)
                hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
            elif past_key_value.shape[2] != key_value_states.shape[1]:
                # checking that the `sequence_length` of the `past_key_value` is the same as
                # the provided `key_value_states` to support prefix tuning
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))
            else:
                # cross-attn
                hidden_states = past_key_value
        return hidden_states

    # get query states
    query_states = shape(
        self.q(hidden_states)
    )  # (batch_size, n_heads, seq_length, dim_per_head)

    # get key/value states
    key_states = project(
        hidden_states,
        self.k,
        key_value_states,
        past_key_value[0] if past_key_value is not None else None,
    )
    value_states = project(
        hidden_states,
        self.v,
        key_value_states,
        past_key_value[1] if past_key_value is not None else None,
    )

    if position_bias is None:
        if not self.has_relative_attention_bias:
            position_bias = torch.zeros(
                (1, self.n_heads, real_seq_length, key_length),
                device=query_states.device,
                dtype=query_states.dtype,
            )
            if self.gradient_checkpointing and self.training:
                position_bias.requires_grad = True
        else:
            position_bias = self.compute_bias(
                real_seq_length, key_length, device=query_states.device
            )

        # if key and values are already calculated
        # we want only the last query position bias
        if past_key_value is not None:
            position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

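        if mask is not None:
            # mask is additive (0 for visible positions, a large negative value for
            # masked ones), so adding it to the bias yields a single float attn_mask
            position_bias = (
                position_bias + mask
            )  # (batch_size, n_heads, seq_length, key_length)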

    if self.pruned_heads:
        mask = torch.ones(position_bias.shape[1])
        mask[list(self.pruned_heads)] = 0
        position_bias_masked = position_bias[:, mask.bool()]
    else:
        position_bias_masked = position_bias

    assert layer_head_mask is None, "Cannot use layer_head_mask when using SDPA kernel"
    assert not output_attentions, "Cannot output attn_weights when using SDPA kernel"
    attn_output = unshape(
        sdpa(
            query_states,
            key_states,
            value_states,
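            attn_mask=position_bias_masked,
            dropout_p=self.dropout if self.training else 0.0,
            # T5 does not scale attention logits by 1/sqrt(dim_per_head), so disable
            # SDPA's default scaling (the scale argument requires PyTorch >= 2.1)
            scale=1.0,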
        )
    )
    attn_output = self.o(attn_output)
    present_key_value_state = (
        (key_states, value_states) if (self.is_decoder and use_cache) else None
    )
    outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

    return outputs

def benchmark_model(
    pipeline: ChronosPipeline,
    gluonts_dataset: str = "m4_hourly",
    batch_size: int = 32,
):
    dataset = get_dataset(gluonts_dataset)
    prediction_length = dataset.metadata.prediction_length
    _, test_template = split(dataset.test, offset=-prediction_length)
    test_data = test_template.generate_instances(prediction_length)
    test_data_input = list(test_data.input)

    start_time = timeit.default_timer()
    forecasts = []
    for idx in range(0, len(test_data_input), batch_size):
        batch = [
            torch.tensor(item["target"])
            for item in test_data_input[idx : idx + batch_size]
        ]
        batch_forecasts = pipeline.predict(batch, prediction_length)
        forecasts.append(batch_forecasts)
    forecasts = torch.cat(forecasts)
    end_time = timeit.default_timer()

    print(f"Inference time: {end_time-start_time:.2f}s")

    results_df = evaluate_forecasts(
        forecasts=[
            SampleForecast(fcst.numpy(), start_date=label["start"])
            for fcst, label in zip(forecasts, test_data.label)
        ],
        test_data=test_data,
        metrics=[MASE(), MeanWeightedSumQuantileLoss(np.arange(0.1, 1, 0.1))],
    )
    results_df["inference_time"] = end_time - start_time
    return results_df

if __name__ == "__main__":
    gluonts_dataset = "m4_hourly"
    models = [
        "amazon/chronos-t5-tiny",
        "amazon/chronos-t5-mini",
        "amazon/chronos-t5-small",
        "amazon/chronos-t5-base",
        "amazon/chronos-t5-large",
    ]
    batch_sizes = [64, 32, 32, 8, 4]

    # Comment out the following line to run the regular transformers version
    T5Attention.forward = sdpa_forward  # Monkey patch forward

    results = []
    for model_name, batch_size in zip(models, batch_sizes):
        pipeline = ChronosPipeline.from_pretrained(
            model_name,
            device_map="cuda:0",
            torch_dtype=torch.float32,  # use torch.bfloat16 for the A100 bf16 runs
        )
        result_df = benchmark_model(
            pipeline, gluonts_dataset=gluonts_dataset, batch_size=batch_size
        )
        result_df["model"] = model_name
        print(result_df)
        results.append(result_df)
    results = pd.concat(results).set_index("model").sort_index()
    print(results)
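The script above times a single pass over the dataset, so one-off costs of the first model call are included in inference_time. A hypothetical helper such as this one (an assumption, not part of the original script) separates warm-up calls from the timed ones and synchronizes CUDA around each measurement so queued GPU work is not misattributed.

```python
import statistics
import timeit

import torch


def time_call(fn, *args, warmup=1, repeats=3, **kwargs):
    """Hypothetical helper: median wall-clock seconds of fn(*args, **kwargs).

    Warm-up calls are not timed, so one-off costs (CUDA context creation,
    kernel selection, caching) don't inflate the measurement.
    """
    for _ in range(warmup):
        fn(*args, **kwargs)
    times = []
    for _ in range(repeats):
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # no earlier GPU work still in flight
        start = timeit.default_timer()
        fn(*args, **kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for this call's GPU work to finish
        times.append(timeit.default_timer() - start)
    return statistics.median(times)
```

In the benchmark it might be used as time_call(pipeline.predict, batch, prediction_length) on a representative batch, run once with the patch applied and once without.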

abdulfatir commented 5 months ago

Hopefully we can have this once https://github.com/huggingface/transformers/pull/30375 is merged.