ml-explore / mlx-data

Efficient framework-agnostic data loading
MIT License
361 stars 40 forks source link

Tokenization is extremely slow--am I doing something wrong? #22

Open andersonbcdefg opened 10 months ago

andersonbcdefg commented 10 months ago

Under what circumstances is MLX supposed to provide a speedup over SentencePiece? In a naive test with the same SPM .model file, I can tokenize 1000 batches (of 500 documents each) in 13 seconds with SentencePiece, but it takes over 5 minutes with MLX. The hardware is an M2 MacBook Pro with 64GB of unified memory. Is the CharTrie tokenization only useful when paired with key_transform? Are there plans to add a "tokenize_batch" with better parallelization/concurrency?

Code for reference:

class Tokenizer:
    """Batch text tokenizer backed by SentencePiece, optionally using MLX Data.

    Token ids come from ``SentencePieceProcessor.encode``. When ``use_mlx`` is
    set and the trie loads successfully, they come instead from an MLX Data
    ``Tokenizer`` built from the same ``.model`` file. Special ids (bos, eos,
    pad, unk) always come from the SentencePiece model.
    """

    def __init__(
        self,
        model_path: str,
        use_mlx: bool = False,
        mlx_tokenize_shortest: bool = False,
    ):
        """Load the SentencePiece model and, optionally, an MLX Data tokenizer.

        Args:
            model_path: Path to a SentencePiece ``.model`` file.
            use_mlx: Tokenize with MLX Data's trie tokenizer instead of
                SentencePiece. Falls back to SentencePiece (with a printed
                message) if the trie cannot be loaded, or if its key count
                does not match the SentencePiece vocabulary size.
            mlx_tokenize_shortest: With MLX, use ``tokenize_shortest``;
                otherwise ``tokenize_rand`` is used.

        Raises:
            AssertionError: If ``model_path`` does not exist.
            ImportError: If ``use_mlx`` is set but MLX Data cannot be imported.
        """
        assert Path(model_path).exists(), model_path
        self.use_mlx = use_mlx
        self.mlx_tokenize_shortest = mlx_tokenize_shortest
        self.spm_model = SentencePieceProcessor(model_file=model_path)
        assert self.spm_model.vocab_size() == self.spm_model.get_piece_size()
        if self.use_mlx:
            try:
                from mlx.data.core import Tokenizer as MLXTokenizer
                from mlx.data.tokenizer_helpers import read_trie_from_spm
            except ImportError as err:
                raise ImportError("Please install MLX to use MLX Tokenizer") from err
            trie, weights = read_trie_from_spm(model_path)
            try:
                self.mlx_model = MLXTokenizer(trie, trie_key_scores=weights)
                print(f"Loaded trie with {trie.num_keys()} keys.")
                # An explicit raise (not ``assert``) so the check survives
                # ``python -O``. A size mismatch is handled like a load failure.
                if self.spm_model.vocab_size() != trie.num_keys():
                    raise ValueError(
                        f"trie has {trie.num_keys()} keys but the SentencePiece "
                        f"vocab has {self.spm_model.vocab_size()} pieces"
                    )
            except Exception as e:
                # Deliberate best-effort: fall back to SentencePiece.
                print(f"Unable to load trie into MLX Tokenizer: {e}. Using SentencePiece instead.")
                self.use_mlx = False

    @property
    def vocab_size(self) -> int:
        """Vocabulary size of the SentencePiece model."""
        return self.spm_model.vocab_size()

    @property
    def bos_id(self) -> int:
        """Beginning-of-sequence id from the SentencePiece model."""
        return self.spm_model.bos_id()

    @property
    def eos_id(self) -> int:
        """End-of-sequence id from the SentencePiece model."""
        return self.spm_model.eos_id()

    @property
    def pad_id(self) -> int:
        """Padding id from the SentencePiece model."""
        return self.spm_model.pad_id()

    @property
    def mask_id(self) -> int:
        """Id used as the mask token (the SentencePiece unknown id)."""
        # unknown kind of spiritually makes sense as mask
        return self.spm_model.unk_id()

    def __call__(
        self,
        t: Union[str, list[str]],
        max_length: int,
        pack: bool = False,
        use_bos: bool = False,
        use_eos: bool = True,
        format: Literal["np", "mx"] = "mx",
        tokens_only: bool = False,
    ) -> Union[np.ndarray, dict]:
        """Tokenize a single string or a batch of strings.

        A single string is wrapped into a one-element batch. All other
        arguments are forwarded to ``encode_batch``. ``format`` is accepted
        for API compatibility but is not applied: outputs are NumPy arrays.
        """
        if isinstance(t, str):
            t = [t]
        # Forward options by keyword. Passing ``format`` positionally put it
        # in ``encode_batch``'s ``tokens_only`` slot, so the (truthy) default
        # "mx" silently forced a tokens-only return.
        return self.encode_batch(
            t,
            max_length,
            pack=pack,
            use_bos=use_bos,
            use_eos=use_eos,
            tokens_only=tokens_only,
        )

    def _pad(self, tokens: "ak.Array", max_length: int):
        """Right-pad (and clip) every row to exactly ``max_length`` with ``pad_id``."""
        tokens = ak.pad_none(tokens, target=max_length, axis=-1, clip=True)
        tokens = ak.fill_none(tokens, self.pad_id)
        return tokens

    def _unpad(self, tokens: "ak.Array"):
        """Drop every ``pad_id`` entry, returning variable-length rows."""
        tokens = ak.from_regular(tokens)
        is_pad = tokens == self.pad_id
        return tokens[~is_pad]

    @staticmethod
    def _special_column(n_rows: int, token_id: int) -> np.ndarray:
        """Return an ``(n_rows, 1)`` int64 column filled with ``token_id``.

        Unlike slicing ``tokens[:, :1]``, this yields one entry per row even
        when a row is empty, so BOS/EOS are never silently skipped.
        """
        return np.full((n_rows, 1), token_id, dtype=np.int64)

    @staticmethod
    def _pack_rows(flat: np.ndarray, max_length: int) -> np.ndarray:
        """Reshape ``flat`` into full rows of ``max_length``, dropping the ragged tail.

        Slicing to an explicit keep-count avoids the ``[:-0]`` pitfall,
        which would return an empty array when the length is already an
        exact multiple of ``max_length``.
        """
        n_keep = len(flat) - len(flat) % max_length
        return flat[:n_keep].reshape(-1, max_length)

    def encode_single_mlx(
        self,
        text: str,
    ):
        """Tokenize one string with the MLX Data tokenizer."""
        if self.mlx_tokenize_shortest:
            tokens = self.mlx_model.tokenize_shortest(text)
        else:
            tokens = self.mlx_model.tokenize_rand(text)
        return tokens

    def encode_batch(
        self,
        batch: list[str],
        max_length: int,
        pack: bool = False,
        use_bos: bool = False,
        use_eos: bool = True,
        tokens_only: bool = False,
    ) -> Union[np.ndarray, dict]:
        """Tokenize ``batch`` and shape it into fixed-width rows.

        Args:
            batch: Texts to tokenize.
            max_length: Row width; must be positive.
            pack: If True, concatenate all sequences and cut the stream into
                rows of ``max_length``, dropping the trailing partial row.
                If False, pad or clip each sequence to ``max_length``.
            use_bos: Prepend ``bos_id`` to every sequence.
            use_eos: Append ``eos_id`` to every sequence.
            tokens_only: Return only the token array instead of the dict.

        Returns:
            Either the token array of shape ``(rows, max_length)``, or a dict
            with ``input_ids``, ``sequence_ids`` and ``attention_mask``.
            In packed mode, ``sequence_ids`` gives each token's source
            sequence index (for block-diagonal attention) and the mask is
            all zeros. Otherwise the mask is 0 for real tokens and ``-inf``
            for padding.
        """
        if self.use_mlx:
            tokens = ak.Array([self.encode_single_mlx(text) for text in batch])
        else:
            tokens = ak.Array(self.spm_model.encode(batch))
        if use_bos:
            tokens = ak.concatenate([self._special_column(len(tokens), self.bos_id), tokens], axis=1)
        if use_eos:
            tokens = ak.concatenate([tokens, self._special_column(len(tokens), self.eos_id)], axis=1)
        if pack:
            # Broadcast each row's index across its tokens, then flatten both.
            sequence_ids = ak.zeros_like(tokens) + range(len(tokens))
            flat_tokens = np.asarray(ak.flatten(tokens, axis=None))
            flat_ids = np.asarray(ak.flatten(sequence_ids, axis=None))
            tokens = self._pack_rows(flat_tokens, max_length)
            sequence_ids = self._pack_rows(flat_ids, max_length)
            if tokens_only:
                return tokens
            return {
                "input_ids": tokens,
                "sequence_ids": sequence_ids, # these are for block diagonal attention
                "attention_mask": np.zeros_like(tokens)
            }

        # keep sequences separate; pad or truncate to max_length
        tokens = np.array(self._pad(tokens, max_length))
        # Additive attention mask: 0 keeps a position, -inf masks padding.
        mask = np.where(tokens != self.pad_id, 0, float("-inf"))
        if tokens_only:
            return tokens
        return {
            "input_ids": tokens,
            "sequence_ids": np.zeros_like(tokens),
            "attention_mask": mask
        }

def test(
    tokenizer_file: str = "tokenizer.model",
    n_iters: int = 1_000,
    batch_size: int = 500,
):
    """Benchmark SentencePiece vs. MLX Data tokenization on a repeated passage.

    Both tokenizers are built up front. Each iteration then tokenizes
    ``batch_size`` copies of the same passage, packed into 512-token rows
    with EOS appended. The total time for ``n_iters`` iterations is printed
    for each backend. If the MLX tokenizer fell back to SentencePiece, the
    "mlx" timing measures SentencePiece too.
    """
    import time

    spm_tokenizer = Tokenizer(tokenizer_file, use_mlx=False)
    mlx_tokenizer = Tokenizer(tokenizer_file, use_mlx=True)

    random_texts = [
        """Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book. It has survived not only five centuries, but also the leap into electronic typesetting, remaining essentially unchanged. It was popularised in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum."""
    ] * batch_size

    for label, tokenizer in (("spm", spm_tokenizer), ("mlx", mlx_tokenizer)):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in tqdm.tqdm(range(n_iters)):
            tokenizer(random_texts, 512, pack=True, use_bos=False, use_eos=True)
        print(f"{label}: {time.perf_counter() - start:.3f}s")

# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    test()
angeloskath commented 9 months ago

Hi @andersonbcdefg, sorry for the extremely late reply.

The tokenizer in MLX Data is actually quite fast for smallish documents. It optimizes over the entire document it is given, so it is considerably slower when passed a huge text like the one above (even though it obviously doesn't make sense to search the whole graph).

For example, the wikitext benchmark (https://github.com/ml-explore/mlx-data/blob/c1204bce12ce495add1ed68338543cb4b5c5a595/benchmarks/comparative/wikitext/mlx_data.py) tokenizes a few million tokens per second on my Mac, which should be more than enough for any use case.

andersonbcdefg commented 9 months ago

Hmm, the document in my example is only a few hundred characters. It's a batch of 500 copies of the same doc, but the doc is short, so I'm not sure that optimizing over a large graph would explain the disparity in speed.

angeloskath commented 9 months ago

Oh, sorry, I misunderstood the code snippet.

Having said that, I wouldn't say it is significantly slower than SPM. Running your benchmark with varying document sizes on my M2 Air laptop, I get the following comparison with SPM:

Doc length (chars) | Batch Tokens | MLX time / SPM time | MLX Tokens/s
-------------------+--------------+---------------------+-------------
              57   |      14336   |              5.40   |  1098130.11
             114   |      28672   |              8.25   |  1125880.12
             172   |      43520   |              9.30   |  1168599.07
             229   |      58368   |              9.89   |  1205666.60
             287   |      72704   |              9.63   |  1182750.15
             344   |      85504   |              11.0   |  1103767.26
             401   |     100864   |              10.9   |  1125994.54
             459   |     113664   |              10.7   |  1130119.24
             516   |     125952   |              10.9   |  1074331.83
              574   |     139264   |              10.5   |  1072074.78

Keep in mind that this is single-core, so >1M tok/s per core is, I think, pretty reasonable for almost all use cases. We would of course appreciate PRs that improve this to reach the speed of SPM, which is probably somewhere around 2M-3M tok/s per core on my machine.

andersonbcdefg commented 9 months ago

Yeah, I hope it can be sped up! A 10x difference in speed matters a lot, especially for offline data-processing workflows (I understand 1M tok/s is fine if you're feeding an LLM in real time, but tokenization speed is also important for batch processing!)

angeloskath commented 9 months ago

Sure, I understand, and we should work on it. However, this is still single-core. When using the following pipeline on my M2 Air, it is 3x slower than SPM:

# Build the MLX Data pipeline one stage at a time. Each call returns a new
# stream object, so stepwise reassignment is equivalent to method chaining.
# Stage 1: stream each text as a {"doc": <UTF-8 bytes>} sample.
dset = dx.stream_python_iterable(lambda: ({"doc": text.encode()} for text in random_texts))
# Stage 2: tokenize the "doc" field with the trie.
dset = dset.tokenize("doc", trie)
# Stage 3: prefetch tokenized samples; presumably (buffer size, thread count) — confirm.
dset = dset.prefetch(20, 4)
# Stage 4: cut "doc" into windows of 512 with step 512.
dset = dset.sliding_window("doc", 512, 512)
# Stage 5: record the size of dim 0 of "doc" under "length" (per mlx-data `shape`; verify).
dset = dset.shape("doc", "length", 0)
# Stages 6-7: group into batches of 128, then prefetch batches.
dset = dset.batch(128)
dset = dset.prefetch(2, 1)