huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

[Summary] Regarding memory issue in tests #18525

Open ydshieh opened 2 years ago

ydshieh commented 2 years ago

Description

This is a short summary of the memory issues in our tests.

The following tests definitely have memory issues

Some tests are also suspicious, but need more investigation.

Pytest itself seems to accumulate some memory as tests continue to run.

This is just my hypothesis: I sometimes see memory grow by a few KB over a sequence of runs even when nothing is leaking.
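To tell a real leak apart from this kind of small drift, it can help to look at the growth caused by each call rather than at a single total. The sketch below is a hypothetical helper (python_heap_growth is not part of transformers or pytest) built on the standard-library tracemalloc module. It only sees allocations made through Python's allocator, so native memory held by TensorFlow, JAX or PyTorch still needs an RSS measurement like the psutil calls used in the scripts in this thread.

```python
import gc
import tracemalloc
from typing import Callable, List


def python_heap_growth(fn: Callable[[], object], n_iter: int = 10, warmup: int = 1) -> List[int]:
    """Hypothetical helper: call `fn` repeatedly and return, for each call,
    how many bytes of traced Python heap were retained compared with the
    previous call. Only allocations made through Python's allocator are
    visible to tracemalloc; native buffers (TF/JAX/PyTorch) are not."""
    for _ in range(warmup):
        fn()  # let one-time caches and lazy imports settle before measuring
    gc.collect()
    tracemalloc.start()
    try:
        deltas: List[int] = []
        previous, _ = tracemalloc.get_traced_memory()
        for _ in range(n_iter):
            fn()
            gc.collect()  # count only memory that is still reachable
            current, _ = tracemalloc.get_traced_memory()
            deltas.append(current - previous)
            previous = current
        return deltas
    finally:
        tracemalloc.stop()
```

A real leak shows up as a roughly constant positive delta on every call; a few stray bytes that do not keep accumulating are more likely bookkeeping noise.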

Possible actions to take

ydshieh commented 2 years ago

TensorFlow hangs if a TF model is forked

This hangs:

import tensorflow as tf
from transformers import TFDistilBertModel, DistilBertConfig
import multiprocessing

config = DistilBertConfig()
config.n_layers = 1
config.n_heads = 2
config.dim = 4
config.hidden_dim = 4

model = TFDistilBertModel(config)

def func(i):

    print(f"func with arg {i}: start")
    inputs = tf.ones(shape=(2, 3), dtype=tf.int32)
    outputs = model(inputs)
    print(f"func with arg {i}: done")
    return outputs

print("start")
with multiprocessing.Pool(processes=1) as pool:
    r = pool.map(func=func, iterable=range(16))

print("all done")
print(len(r))
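One possible way around this, assuming the hang comes from forking a process in which TensorFlow has already started its runtime (a common cause of fork hangs, not confirmed here), is to start the workers with the "spawn" method, as the PyTorch script later in this thread does with multiprocessing.get_context('spawn'), and to build the model inside each worker instead of capturing it from the parent. The helper name map_in_spawned_workers is hypothetical, and math.factorial only stands in for the per-worker TF call.

```python
import math
import multiprocessing
from typing import Callable, Iterable, List


def map_in_spawned_workers(func: Callable, iterable: Iterable, processes: int = 1) -> List:
    """Run `func` over `iterable` in worker processes started with "spawn".

    Spawned workers begin from a fresh interpreter instead of a copy of the
    parent, so they do not inherit threads or locks the parent has set up.
    `func` and its arguments must be picklable (defined at module top level),
    which means a model has to be created inside the worker.
    """
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(processes=processes) as pool:
        return pool.map(func, iterable)


if __name__ == "__main__":
    # Stand-in workload; a real worker would build and call its own TF model.
    print(map_in_spawned_workers(math.factorial, range(5)))  # [1, 1, 2, 6, 24]
```

The __main__ guard matters with "spawn": each worker re-imports the main module, and without the guard it would try to start its own pool.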
ydshieh commented 2 years ago

Strange hanging with TensorFlow Probability

Running the tests with --forked

python3 -m pytest --forked -n 2 --max-worker-restart=0 --dist=loadfile -s --make-reports=tests_tf tests/models/auto/test_modeling_tf_auto.py | tee tests_output.txt

hangs when tensorflow-probability is installed. After uninstalling tensorflow-probability, the tests finish quickly.

(I am not sure what tensorflow-probability is doing here, though.)


Actually, running the following also hangs:

python3 -m pytest --forked -v test_tf.py

with test_tf.py being

from transformers import TFAutoModelWithLMHead

#import tensorflow_probability as tfp
from transformers.models.tapas.modeling_tf_tapas import TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST

def test_foo():
    model = TFAutoModelWithLMHead.from_pretrained("julien-c/dummy-unknown")
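test_tf.py does not import tensorflow_probability itself (that line is commented out), yet the hang appears tied to the package being installed. One hypothesis is that importing the TF TAPAS module pulls tensorflow_probability in indirectly. The hypothetical helper below checks this in a clean interpreter, so modules already loaded by pytest or by the current process do not skew the answer.

```python
import subprocess
import sys


def pulled_in_by_import(module: str, watch: str) -> bool:
    """Return True if importing `module` in a fresh interpreter also loads
    the module named `watch`. Raises CalledProcessError if the import fails."""
    code = (
        "import importlib, sys\n"
        f"importlib.import_module({module!r})\n"
        f"print({watch!r} in sys.modules)\n"
    )
    result = subprocess.run(
        [sys.executable, "-c", code], capture_output=True, text=True, check=True
    )
    return result.stdout.strip() == "True"
```

For this issue the call would be pulled_in_by_import("transformers.models.tapas.modeling_tf_tapas", "tensorflow_probability"), run in the environment where the hang reproduces.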
ydshieh commented 2 years ago

--forked hang with Flax tests

Running the following test with --forked hangs:

python3 -m pytest --forked -v test_flax.py

with test_flax.py being

def test_flax_foo():

    from transformers import FlaxDistilBertModel, DistilBertConfig
    import numpy as np

    config = DistilBertConfig()
    config.n_layers = 1
    config.n_heads = 2
    config.dim = 4
    config.hidden_dim = 4
    model = FlaxDistilBertModel(config)
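Until the cause is known, one workaround is to avoid fork altogether and run each test file in its own fresh interpreter, with a timeout so that a hang becomes a visible failure instead of blocking the whole job. This is a minimal sketch with a hypothetical helper name; subprocess.run kills the child process when the timeout expires.

```python
import subprocess
import sys
from typing import Optional, Sequence


def run_with_timeout(args: Sequence[str], timeout: float) -> Optional[int]:
    """Run a command in a fresh process and return its exit code, or None if
    it did not finish within `timeout` seconds (the process is then killed)."""
    try:
        return subprocess.run(list(args), timeout=timeout).returncode
    except subprocess.TimeoutExpired:
        return None
```

For the example above this would be run_with_timeout([sys.executable, "-m", "pytest", "-v", "test_flax.py"], timeout=600); the timeout value is arbitrary.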
ydshieh commented 2 years ago

cc @LysandreJik for reading :-)

ydshieh commented 2 years ago

To ease debugging, the code snippet below is a self-contained script that runs FlaxBart. The results look like this:

(mem_FlaxBartForConditionalGeneration.json, the memory usage in KB: the script records psutil's RSS, which is in bytes, divided by 1024)

[
    157772.0,
    823724.0,
    850768.0,
    878004.0,
    905340.0,
    933288.0,
    959816.0,
    986800.0,
    1013596.0,
    1041560.0,
    1067088.0,
    1095960.0,
    1121640.0,
    1149596.0,
    1175144.0,
    1203396.0,
    1228764.0,
    1256536.0,
    1282528.0,
    1309668.0,
    1337724.0,
    1362584.0,
    1390300.0,
    1417172.0,
    1443084.0,
    1471568.0,
    1494896.0,
    1500424.0,
    1512176.0,
    1519920.0,
    1529484.0
]
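Reading these numbers as KiB (the script divides psutil's byte count by 1024), a short calculation separates the one-time jump on the first call from the steady growth on later calls. This is only arithmetic on the values listed above.

```python
import statistics

# RSS samples from mem_FlaxBartForConditionalGeneration.json above, in KiB
# (psutil reports bytes; the script divides by 1024). Index 0 is measured
# before the loop, index i after the i-th call to test_beam_search_generate.
rss_kib = [
    157772.0, 823724.0, 850768.0, 878004.0, 905340.0, 933288.0, 959816.0,
    986800.0, 1013596.0, 1041560.0, 1067088.0, 1095960.0, 1121640.0,
    1149596.0, 1175144.0, 1203396.0, 1228764.0, 1256536.0, 1282528.0,
    1309668.0, 1337724.0, 1362584.0, 1390300.0, 1417172.0, 1443084.0,
    1471568.0, 1494896.0, 1500424.0, 1512176.0, 1519920.0, 1529484.0,
]

# Growth caused by each call: the first one includes one-time costs
# (e.g. model setup and JIT compilation), later ones show what is retained.
deltas = [after - before for before, after in zip(rss_kib, rss_kib[1:])]
first_call = deltas[0]
typical_later_call = statistics.median(deltas[1:])

print(f"first call:        +{first_call / 1024:.0f} MiB")
print(f"median later call: +{typical_later_call / 1024:.1f} MiB")
```

With these numbers the first call adds about 650 MiB and each later call about 26 MiB (median), which is in the 10~30 MB range reported further down.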

Here is the code snippet that runs test_beam_search_generate. (It strips out all unittest elements and runs without pytest.)

import copy
import json
import numpy as np
import os
import psutil
import random
import jax.numpy as jnp
from jax import jit

from transformers import BartConfig, FlaxBartModel, FlaxBartForConditionalGeneration, FlaxBartForSequenceClassification, FlaxBartForQuestionAnswering

def ids_tensor(shape, vocab_size, rng=None):
    """Creates a random int32 tensor of the shape within the vocab size."""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    output = np.array(values, dtype=jnp.int32).reshape(shape)

    return output

def random_attention_mask(shape, rng=None):
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
    # make sure that at least one token is attended to for each batch
    attn_mask[:, -1] = 1
    return attn_mask

def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id

    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids

def prepare_bart_inputs_dict(
    config,
    input_ids,
    decoder_input_ids=None,
    attention_mask=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
):
    if attention_mask is None:
        attention_mask = np.where(input_ids != config.pad_token_id, 1, 0)
    if decoder_attention_mask is None:
        decoder_attention_mask = np.where(decoder_input_ids != config.pad_token_id, 1, 0)
    if head_mask is None:
        head_mask = np.ones((config.encoder_layers, config.encoder_attention_heads))
    if decoder_head_mask is None:
        decoder_head_mask = np.ones((config.decoder_layers, config.decoder_attention_heads))
    if cross_attn_head_mask is None:
        cross_attn_head_mask = np.ones((config.decoder_layers, config.decoder_attention_heads))
    return {
        "input_ids": input_ids,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "decoder_attention_mask": attention_mask,
    }

class FlaxBartModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=4,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=32,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
        initializer_range=0.02,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.initializer_range = initializer_range

    def prepare_config_and_inputs(self):
        input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size), 3, self.vocab_size)
        input_ids = np.concatenate((input_ids, 2 * np.ones((self.batch_size, 1), dtype=np.int64)), -1)

        decoder_input_ids = shift_tokens_right(input_ids, 1, 2)

        config = BartConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            initializer_range=self.initializer_range,
            use_cache=False,
        )
        inputs_dict = prepare_bart_inputs_dict(config, input_ids, decoder_input_ids)
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

class FlaxBartModelTest:
    is_encoder_decoder = True

    def __init__(self, model_class):
        self.model_tester = FlaxBartModelTester(self)
        self.model_class = model_class

    def _prepare_for_class(self, inputs_dict, model_class):
        inputs_dict = copy.deepcopy(inputs_dict)

        # hack for now until we have AutoModel classes
        if "ForMultipleChoice" in model_class.__name__:
            inputs_dict = {
                k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
                if isinstance(v, (jnp.ndarray, np.ndarray))
                else v
                for k, v in inputs_dict.items()
            }

        return inputs_dict

    def _get_input_ids_and_config(self):
        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

        # cut to half length & take max batch_size 2
        max_batch_size = 2
        sequence_length = inputs["input_ids"].shape[-1] // 2
        input_ids = inputs["input_ids"][:max_batch_size, :sequence_length]

        attention_mask = jnp.ones_like(input_ids)
        attention_mask = attention_mask[:max_batch_size, :sequence_length]

        # generate max 5 tokens
        max_length = input_ids.shape[-1] + 5
        if config.eos_token_id is not None and config.pad_token_id is None:
            # hack to allow generate for models such as GPT2 as is done in `generate()`
            config.pad_token_id = config.eos_token_id
        return config, input_ids, attention_mask, max_length

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model_inputs = self._prepare_for_class(inputs_dict, model_class)
            outputs = model(**model_inputs)

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        inputs_dict["output_hidden_states"] = True
        check_hidden_states_output(inputs_dict, config, self.model_class)

        # check that output_hidden_states also work using config
        del inputs_dict["output_hidden_states"]
        config.output_hidden_states = True

        check_hidden_states_output(inputs_dict, config, self.model_class)

    def test_beam_search_generate(self):
        config, input_ids, _, max_length = self._get_input_ids_and_config()
        config.do_sample = False
        config.max_length = max_length
        config.num_beams = 2

        model = self.model_class(config)

        generation_outputs = model.generate(input_ids).sequences
        jit_generate = jit(model.generate)
        jit_generation_outputs = jit_generate(input_ids).sequences

if __name__ == "__main__":

    all_model_classes = (
        (
            # FlaxBartModel,
            FlaxBartForConditionalGeneration,
            # FlaxBartForSequenceClassification,
            # FlaxBartForQuestionAnswering,
        )
    )

    for model_class in all_model_classes:

        test = FlaxBartModelTest(model_class)
        all_rss = []

        p = psutil.Process(os.getpid())
        m = p.memory_full_info()
        rss = m.rss / 1024
        all_rss.append(rss)

        for i in range(30):

            # This is fine
            # test.test_hidden_states_output()

            # Mem. leak
            test.test_beam_search_generate()

            m = p.memory_full_info()
            rss = m.rss / 1024
            all_rss.append(rss)

            fn = f"mem_{model_class.__name__}.json"

            with open(fn, "w") as fp:
                json.dump(all_rss, fp, ensure_ascii=False, indent=4)
LysandreJik commented 2 years ago

Thanks for summarizing all the info, @ydshieh!

ydshieh commented 2 years ago

To debug test_torch_fx more easily:

The script below runs both modes (with and without a new process per iteration) for n_iter = 500 iterations:

import copy
import torch
import tempfile
import os
import json
import pickle
import psutil
import multiprocessing

from transformers.utils.fx import symbolic_trace
from transformers import BartConfig, BartModel

torch_device = "cpu"
model_class = BartModel
config_dict = {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "d_model": 16,
  "decoder_attention_heads": 4,
  "decoder_ffn_dim": 4,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 2,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "encoder_attention_heads": 4,
  "encoder_ffn_dim": 4,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 2,
  "eos_token_id": 2,
  "forced_eos_token_id": None,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": True,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 20,
  "model_type": "bart",
  "num_hidden_layers": 2,
  "pad_token_id": 1,
  "scale_embedding": False,
  "transformers_version": "4.22.0.dev0",
  "use_cache": True,
  "vocab_size": 99
}
config = BartConfig(**config_dict)
inputs = {
    'input_ids': torch.tensor([
        [22, 30, 84, 13, 46, 95,  2],
        [74, 91, 58, 38,  3, 48,  2],
        [43, 32, 21, 60, 12, 42,  2],
        [20, 24, 75, 46, 62, 55,  2],
        [59, 91, 36, 57, 40, 36,  2],
        [23, 24, 33, 70, 13, 93,  2],
        [15,  4, 11, 45,  5, 87,  2],
        [78, 76, 67, 38,  3, 46,  2],
        [ 3, 31, 35, 85, 81, 46,  2],
        [47, 45, 97, 80, 75, 91,  2],
        [92, 49, 42, 65, 74, 98,  2],
        [67, 37, 84, 88, 55, 57,  2],
        [24, 53, 44, 36, 45, 24,  2],
    ], dtype=torch.int32),
    'decoder_input_ids': torch.tensor([
        [50, 56, 84, 91, 16, 49, 54],
        [ 2, 71, 62, 39, 27,  4, 93],
        [73, 45, 61, 63, 35, 25,  7],
        [27, 33, 23, 86, 13, 49, 32],
        [74, 36, 46, 83, 18, 40, 22],
        [45, 69, 41,  3, 29, 56, 49],
        [ 3, 38,  8, 52, 17, 55, 15],
        [63, 79, 42, 64, 62, 39, 40],
        [28, 59, 69, 14, 77, 45, 36],
        [56, 55, 82, 35, 66, 51, 19],
        [18, 96, 43, 34, 16, 69, 94],
        [68, 65, 52, 17, 77, 78, 54],
        [68, 57, 74, 42, 60, 13, 91]
    ]),
    'attention_mask': torch.tensor([
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True]
    ], dtype=torch.bool),
    'decoder_attention_mask': torch.tensor([
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True]
    ], dtype=torch.bool),
    'head_mask': torch.tensor([[1., 1., 1., 1.], [1., 1., 1., 1.]]),
    'decoder_head_mask': torch.tensor([[1., 1., 1., 1.], [1., 1., 1., 1.]]),
    'cross_attn_head_mask': torch.tensor([[1., 1., 1., 1.], [1., 1., 1., 1.]])
}

def _config_zero_init(config):
    configs_no_init = copy.deepcopy(config)
    for key in configs_no_init.__dict__.keys():
        if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
            setattr(configs_no_init, key, 1e-10)
    return configs_no_init

def _run_torch_jit(in_queue, out_queue):

    model, input_names, filtered_inputs = in_queue.get()
    traced_model = symbolic_trace(model, input_names)
    # blocks (hangs) here if this process is forked instead of spawned
    with torch.no_grad():
        traced_output = traced_model(**filtered_inputs)

    # Test that the model can be TorchScripted
    scripted = torch.jit.script(traced_model)
    with torch.no_grad():
        scripted_output = scripted(**filtered_inputs)

    out_queue.put((traced_model, scripted_output))
    out_queue.join()

def create_and_check_torch_fx_tracing(model_class, config, inputs, n_iter=100, with_new_proc=False):

    configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
    configs_no_init.return_dict = False

    model = model_class(config=configs_no_init)
    model.to(torch_device)
    model.eval()

    model.config.use_cache = False
    input_names = [
        "attention_mask",
        "decoder_attention_mask",
        "decoder_input_ids",
        "input_features",
        "input_ids",
        "input_values",
    ]

    filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
    input_names = list(filtered_inputs.keys())

    model_output = model(**filtered_inputs)

    all_rss = []

    p = psutil.Process(os.getpid())
    m = p.memory_full_info()
    rss = m.rss / 1024
    all_rss.append(rss)

    for i in range(n_iter):

        print(f"idx: {i} - start")

        if not with_new_proc:

            traced_model = symbolic_trace(model, input_names)
            with torch.no_grad():
                traced_output = traced_model(**filtered_inputs)

            # Test that the model can be TorchScripted
            scripted = torch.jit.script(traced_model)
            with torch.no_grad():
                scripted_output = scripted(**filtered_inputs)

        else:

            ctx = multiprocessing.get_context('spawn')

            in_queue = ctx.Queue()
            out_queue = ctx.JoinableQueue()

            in_queue.put((model, input_names, filtered_inputs))

            process = ctx.Process(target=_run_torch_jit, args=(in_queue, out_queue))
            process.start()
            traced_model, scripted_output = out_queue.get()
            out_queue.task_done()
            process.join()

        print(f"idx: {i} - end")
        print("=" * 40)

        m = p.memory_full_info()
        rss = m.rss / 1024
        all_rss.append(rss)

        fn = f"torch_jit_script_mem_with_new_proc={with_new_proc}.json"

        with open(fn, "w") as fp:
            json.dump(all_rss, fp, ensure_ascii=False, indent=4)

if __name__ == "__main__":

    create_and_check_torch_fx_tracing(model_class, config, inputs, n_iter=500, with_new_proc=True)
    create_and_check_torch_fx_tracing(model_class, config, inputs, n_iter=500, with_new_proc=False)
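In the with_new_proc=False branch, the per-iteration traced_model and scripted objects live only in local variables, so if RSS still climbs, one question is whether those objects are actually freed. The hypothetical helper below uses weakref to check this for Python-level references, assuming the objects support weak references (ordinary class instances such as nn.Module subclasses do). It says nothing about native memory allocated inside TorchScript.

```python
import gc
import weakref
from typing import Callable


def is_released(factory: Callable[[], object]) -> bool:
    """Create an object with `factory`, drop our only reference to it, and
    report whether it was actually freed. If something else (e.g. a
    module-level cache or registry) still holds it, the weak reference stays
    alive and this returns False."""
    obj = factory()
    ref = weakref.ref(obj)
    del obj
    gc.collect()  # also collect reference cycles, not just refcount-zero objects
    return ref() is None
```

In the script above this could be tried as is_released(lambda: symbolic_trace(model, input_names)); a False result would point at something retaining the traced modules.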
ydshieh commented 2 years ago

@patil-suraj @sanchit-gandhi @patrickvonplaten

We have a memory leak in some Flax tests. So far I have observed it in test_beam_search_generate, test_beam_search_generate_attn_mask and test_beam_search_generate_logits_warper, but there may be more. Each call to one of them increases memory usage by 10~30 MB.

The CircleCI job run page also shows memory issue in Flax testing (https://app.circleci.com/pipelines/github/huggingface/transformers/45317/workflows/5bcb8b8a-776c-4c58-ad99-cf2700304c05/jobs/528556/resources)

To reproduce, see here for test_beam_search_generate.

Not very urgent, but this will cause trouble as more models are added. Could you have a look, please? Let me know if you need more information.

patrickvonplaten commented 2 years ago

Hey @ydshieh,

I'm a bit under water at the moment - I'll put the issue on my TODO-list, but I can't promise to find time to look into it very soon. This link: https://app.circleci.com/pipelines/github/huggingface/transformers/45317/workflows/5bcb8b8a-776c-4c58-ad99-cf2700304c05/jobs/528556/resources doesn't seem to show anything useful to me.

Also just to understand better, are the flax tests running on GPU or CPU?