bmoscon / orderbook

A fast L2/L3 orderbook data structure, in C, for Python
GNU General Public License v3.0
258 stars 52 forks source link

Use a skip list to improve CPU performance 4x #15

Closed xiangsf closed 2 years ago

xiangsf commented 2 years ago

Is your feature request related to a problem? Please describe. Currently, this library uses Python's SortedDict to store order book data. When fetching Coinbase BTC-USD data, the calls book.book.bids.index(0)/book.book.asks.index(0) in the l2_book/l3_book callback functions cost 40%-50% CPU.

Describe the solution you'd like: use a skip list to store the order book, which brings CPU usage down to only 10%. All test code is in this zip file: skiplist.zip

Additional context: test code 1, using Python's SortedDict, uses 43% CPU:

import os, sys, random
from skiplist import SkipList
from decimal import Decimal

import cryptofeed.types as cryptofeed_types
from cryptofeed import FeedHandler
from cryptofeed.defines import CANDLES, BID, ASK, L2_BOOK, L3_BOOK, LIQUIDATIONS, OPEN_INTEREST, PERPETUAL, TICKER, TRADES
from cryptofeed.exchanges import (Binance, Coinbase)

import logging

# Log line layout: timestamp, level, process id, source location, message.
LOG_FORMAT = (
    "%(asctime)s %(levelname)-2s %(process)d "
    "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s"
)
logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)

async def book(book, receipt_timestamp):
    """Benchmark callback for the default order book.

    Reads the top level of each side with positional ``index(0)`` lookups
    (asks once, bids twice) and discards the results. Returns early when
    either side is empty. *receipt_timestamp* is not used.
    """
    bid_side = book.book.bids
    ask_side = book.book.asks
    if not (len(bid_side) > 0 and len(ask_side) > 0):
        return
    # The bare lookups are intentional: their CPU cost is what is being compared.
    ask_side.index(0)
    bid_side.index(0)
    _best_price, _best_size = bid_side.index(0)

def do_test5_coinbase():
    """Run a Coinbase BTC-USD L2 book feed using the ``book`` callback.

    Uses the library's default order book type. No snapshot handler is
    patched here.
    """
    handler = FeedHandler()
    handler.add_feed(
        Coinbase(
            subscription={L2_BOOK: ['BTC-USD']},
            callbacks={L2_BOOK: book},
        )
    )
    handler.run()

Test code 2, using the skip list, uses 9.6% CPU:


class OrderBookBidAask:
    """Two-sided order book container backed by ``SkipList`` instances.

    Bids are stored in a ``SkipList(reverse=True)`` (descending keys) and
    asks in a plain ``SkipList`` (ascending keys). Each side is reachable
    as an attribute (``bids``/``asks``), through the aliases ``bid``/``ask``,
    or by subscripting with ``'bid'``/``'bids'`` or ``'ask'``/``'asks'``.

    NOTE(review): ``max_depth``, ``checksum_format`` and ``max_depth_strict``
    are accepted for call compatibility but are not used.
    """

    # Tuples (not sets) so that membership uses ``==`` and never needs hashing.
    _BID_KEYS = ('bid', 'bids')
    _ASK_KEYS = ('ask', 'asks')

    def __init__(self, max_depth=None, checksum_format=None, max_depth_strict=None):
        self.bids = SkipList(reverse=True)
        self.asks = SkipList()

    # ``bid``/``ask`` are live aliases. Previously they were plain attributes
    # bound once in __init__, so reassigning ``bids``/``asks`` (e.g. via
    # ``book['bids'] = ...``) left them pointing at the stale object.
    @property
    def bid(self):
        """Alias of :attr:`bids`."""
        return self.bids

    @bid.setter
    def bid(self, value):
        self.bids = value

    @property
    def ask(self):
        """Alias of :attr:`asks`."""
        return self.asks

    @ask.setter
    def ask(self, value):
        self.asks = value

    def __getitem__(self, key):
        """Return the bid or ask side for *key*.

        Raises ``KeyError`` for any other key.
        """
        if key in self._BID_KEYS:
            return self.bids
        if key in self._ASK_KEYS:
            return self.asks
        raise KeyError(key)

    def __delitem__(self, key):
        """Deleting a side is not supported."""
        raise TypeError(f"{type(self).__name__} does not support item deletion")

    def __setitem__(self, key, value):
        """Replace the bid or ask side for *key*.

        Raises ``KeyError`` for any other key.
        """
        if key in self._BID_KEYS:
            self.bids = value
        elif key in self._ASK_KEYS:
            self.asks = value
        else:
            raise KeyError(key)

class SkipListOrderBook:
    """Order book wrapper whose price levels live in an ``OrderBookBidAask``.

    Exposes the attributes the feed callbacks in this file read:
    ``exchange``, ``symbol``, ``book`` (with ``bids``/``asks`` sides),
    ``delta``, ``timestamp``, ``sequence_number``, ``checksum`` and ``raw``.

    NOTE(review): ``max_depth``, ``truncate`` and ``checksum_format`` are
    only forwarded to ``OrderBookBidAask``. No depth limit or checksum is
    enforced in this class.
    """

    def __init__(self, exchange, symbol, bids=None, asks=None, max_depth=0, truncate=False, checksum_format=None):
        self.exchange = exchange
        self.symbol = symbol
        self.book = OrderBookBidAask(max_depth=max_depth, checksum_format=checksum_format, max_depth_strict=truncate)
        # Seed each side from an optional ``{price: size}`` mapping.
        if bids:
            for price, size in bids.items():
                self.book.bids[price] = size
        if asks:
            for price, size in asks.items():
                self.book.asks[price] = size
        self.delta = None
        self.timestamp = None
        self.sequence_number = None
        self.checksum = None
        self.raw = None

    @staticmethod
    def from_dict(data: dict):
        """Build a book from a dict with ``exchange``, ``symbol``, ``book``
        (keyed by ``BID``/``ASK``), ``timestamp`` and an optional ``delta``."""
        ob = SkipListOrderBook(data['exchange'], data['symbol'], bids=data['book'][BID], asks=data['book'][ASK])
        ob.timestamp = data['timestamp']
        if 'delta' in data:
            ob.delta = data['delta']
        return ob

    def _delta(self, numeric_type) -> dict:
        """Return ``self.delta`` with every ``Decimal`` converted by *numeric_type*.

        Non-``Decimal`` values are passed through unchanged, and each level
        is returned as a tuple.
        """
        def convert(levels):
            return [tuple(numeric_type(v) if isinstance(v, Decimal) else v for v in level) for level in levels]

        return {
            BID: convert(self.delta[BID]),
            ASK: convert(self.delta[ASK]),
        }

    def to_dict(self, delta=False, numeric_type=None, none_to=False):
        """Not implemented for the SkipList-backed book."""
        raise NotImplementedError(f"{type(self).__name__}.to_dict is not implemented")

    def _levels(self):
        # Snapshot of both sides as ordered (price, size) lists, for comparison.
        return list(self.book.bids.items()), list(self.book.asks.items())

    def __repr__(self):
        return f"exchange: {self.exchange} symbol: {self.symbol} book: {self.book} timestamp: {self.timestamp}"

    def __eq__(self, cmp):
        """Compare metadata and both sides' price levels.

        The levels are compared directly because ``OrderBookBidAask`` has no
        ``to_dict`` method. The previous implementation called it and always
        raised ``AttributeError``.
        """
        if not isinstance(cmp, SkipListOrderBook):
            return NotImplemented
        return (
            self.exchange == cmp.exchange
            and self.symbol == cmp.symbol
            and self.delta == cmp.delta
            and self.timestamp == cmp.timestamp
            and self.sequence_number == cmp.sequence_number
            and self.checksum == cmp.checksum
            and self._levels() == cmp._levels()
        )

    def __hash__(self):
        # Hash only fields that __eq__ also compares, so equal books hash
        # equally. The old repr-based hash embedded the inner book's
        # identity-based repr and broke that contract.
        return hash((self.exchange, self.symbol, self.timestamp, self.sequence_number))

async def Coinbase_pair_level2_snapshot(self, msg: dict, timestamp: float):
    """Handle a Coinbase level2 snapshot by storing it in a SkipListOrderBook.

    Intended to replace ``Coinbase._pair_level2_snapshot`` (see
    ``do_test6_skiplist``), so *self* is the Coinbase feed instance.

    A snapshot describes the complete book. If a book already exists for the
    pair, both sides are cleared before the snapshot levels are written.
    Merging into the old book would leave stale price levels that are absent
    from the new snapshot.
    """
    pair = self.exchange_symbol_to_std_symbol(msg['product_id'])
    bids = {Decimal(price): Decimal(amount) for price, amount in msg['bids']}
    asks = {Decimal(price): Decimal(amount) for price, amount in msg['asks']}
    if pair not in self._l2_book:
        self._l2_book[pair] = SkipListOrderBook(self.id, pair, max_depth=self.max_depth, bids=bids, asks=asks)
    else:
        sides = self._l2_book[pair].book
        # Clear in place (rather than replacing the book object) so anyone
        # holding a reference to the existing book keeps seeing it.
        sides.bids.clear()
        sides.asks.clear()
        for price, size in bids.items():
            sides.bids[price] = size
        for price, size in asks.items():
            sides.asks[price] = size
    await self.book_callback(L2_BOOK, self._l2_book[pair], timestamp, raw=msg)

async def book_skiplist(book, receipt_timestamp):
    """Benchmark callback for the SkipList-backed book.

    Iterates the levels returned by ``index_slice(0, 101)`` on the asks and
    then the bids, and finally reads ``bids.index(0)``. All results are
    discarded. Returns early when either side is empty. *receipt_timestamp*
    is not used.
    """
    bids = book.book.bids
    asks = book.book.asks
    if not (len(bids) > 0 and len(asks) > 0):
        return
    for side in (asks, bids):
        for _price, _size in side.index_slice(0, 101):
            pass
    _price, _size = bids.index(0)

def do_test6_skiplist():
    """Run a Coinbase BTC-USD L2 book feed backed by ``SkipListOrderBook``.

    First patches ``Coinbase._pair_level2_snapshot`` so snapshots build
    SkipList-backed books, then starts a ``FeedHandler`` with the
    ``book_skiplist`` callback.
    """
    Coinbase._pair_level2_snapshot = Coinbase_pair_level2_snapshot

    handler = FeedHandler()
    feed = Coinbase(
        subscription={L2_BOOK: ['BTC-USD']},
        callbacks={L2_BOOK: book_skiplist},
    )
    handler.add_feed(feed)
    handler.run()

Skip list implementation:

#
# This file is part of Bluepass. Bluepass is Copyright (c) 2012-2014
# Geert Jansen.
#
# Bluepass is free software available under the GNU General Public License,
# version 3. See the file LICENSE distributed with this file for the exact
# licensing terms.
# This code is copied from https://github.com/geertj/pyskiplist

from __future__ import absolute_import, print_function

import os
import sys
import math
import random

class SkipList(object):
    """An indexable skip list.

    A SkipList provides an ordered sequence of key-value pairs. The list is
    always sorted on key and supports O(1) forward iteration. It has O(log N)
    time complexity for key lookup, pair insertion and pair removal anywhere in
    the list. The list also supports O(log N) element access by position.

    With ``reverse=True`` the list is kept in descending key order. Every key
    comparison, including the *start*/*stop* bounds of :meth:`items`, then
    uses that reversed ordering.

    The keys of all pairs you add to the skiplist must be comparable against
    each other, and define the ``<``, ``<=``, ``>`` and ``>=`` operators.
    """

    UNSET = object()

    p = int((1<<31) / math.e)
    maxlevel = 20

    # Kudos to http://pythonsweetness.tumblr.com/post/45227295342 for some
    # useful tricks, including using a list for the nodes to save memory.

    # Use the built-in Mersenne Twister random number generator. It is more
    # appropriate than SystemRandom because we don't need cryptographically
    # secure random numbers, and we don't want to do a system call to read
    # /dev/urandom for each random number we need (every insertion needs a new
    # random number).

    _rnd = random.Random()
    _rnd.seed(os.urandom(16))

    __slots__ = ('_reverse', '_level', '_head', '_tail', '_path', '_distance')

    def __init__(self, reverse=False):
        self._reverse = reverse
        self._level = 1
        # Head and tail are sentinels with the maximum number of links.
        self._head = self._new_node(self.maxlevel, None, None)
        self._tail = self._new_node(self.maxlevel, None, None)
        for i in range(self.maxlevel):
            self._head[2+i] = self._tail
        # Scratch buffers filled by the _find_* methods: per-level predecessor
        # node and its position (number of nodes before it).
        self._path = [None] * self.maxlevel
        self._distance = [None] * self.maxlevel

    def _is_key_lt(self, key1, key2):
        # "Strictly before" in list order (``>`` when reversed).
        if not self._reverse:
            if key1 < key2:
                return True
        else:
            if key1 > key2:
                return True
        return False

    def _is_key_lte(self, key1, key2):
        # "Before or equal" in list order (``>=`` when reversed).
        if not self._reverse:
            if key1 <= key2:
                return True
        else:
            if key1 >= key2:
                return True
        return False

    def _new_node(self, level, key, value):
        # Node layout: [key, value, next*LEVEL, skip?]
        # The "skip" element indicates how many nodes are skipped by the
        # highest level incoming link.
        if level == 1:
            return [key, value, None]
        else:
            return [key, value] + [None]*level + [0]

    def _random_level(self):
        # Exponential distribution as per Pugh's paper.
        l = 1
        maxlevel = min(self.maxlevel, self.level+1)
        while l < maxlevel and self._rnd.getrandbits(31) < self.p:
            l += 1
        return l

    def _create_node(self, key, value):
        # Create a new node, updating the list level if required.
        level = self._random_level()
        if level > self.level:
            self._tail[-1] = len(self)
            self._level = level
            self._path[level-1] = self._head
            self._distance[level-1] = 0
        return self._new_node(level, key, value)

    def _find_lt(self, key):
        # Find path to last node < key
        node = self._head
        distance = 0
        for i in reversed(range(self.level)):
            nnode = node[2+i]
            while nnode is not self._tail and self._is_key_lt(nnode[0], key):
                nnode, node = nnode[2+i], nnode
                distance += 1 if i == 0 else node[-1]
            self._path[i] = node
            self._distance[i] = distance

    def _find_lte(self, key):
        # Find path to last node <= key
        node = self._head
        distance = 0
        for i in reversed(range(self.level)):
            nnode = node[2+i]
            while nnode is not self._tail and self._is_key_lte(nnode[0], key):
                nnode, node = nnode[2+i], nnode
                distance += 1 if i == 0 else node[-1]
            self._path[i] = node
            self._distance[i] = distance

    def _find_pos(self, pos):
        # Create path to node at pos.
        node = self._head
        distance = 0
        for i in reversed(range(self.level)):
            nnode = node[2+i]
            ndistance = distance + (1 if i == 0 else nnode[-1])
            while nnode is not self._tail and ndistance <= pos:
                nnode, node, distance = nnode[2+i], nnode, ndistance
                ndistance += 1 if i == 0 else nnode[-1]
            self._path[i] = node
            self._distance[i] = distance

    def _insert(self, node):
        # Insert a node in the list. The _path and _distance must be set.
        path, distance = self._path, self._distance
        # Update pointers
        level = max(1, len(node) - 3)
        for i in range(level):
            node[2+i] = path[i][2+i]
            path[i][2+i] = node
        if level > 1:
            node[-1] = 1 + distance[0] - distance[level-1]
        # Update skip counts
        node = node[2]
        i = 2; j = min(len(node) - 3, self.level)
        while i <= self.level:
            while j < i:
                node = node[i]
                j = min(len(node) - 3, self.level)
            node[-1] -= distance[0] - distance[j-1] if j <= level else -1
            i = j+1

    def _remove(self, node):
        # Remove a node. The _path and _distance must be set.
        path, distance = self._path, self._distance
        level = max(1, len(node) - 3)
        for i in range(level):
            path[i][2+i] = node[2+i]
        # Update skip counts
        value = node[1]
        node = node[2]
        i = 2; j = min(len(node) - 3, self.level)
        while i <= self.level:
            while j < i:
                node = node[i]
                j = min(len(node) - 3, self.level)
            node[-1] += distance[0] - distance[j-1] if j <= level else -1
            i = j+1
        # Reduce level if last node on current level was removed
        while self.level > 1 and self._head[1+self.level] is self._tail:
            self._level -= 1
            # len(self) now walks the lower level using the stale tail skip;
            # this recovers the tail skip for the new top level.
            self._tail[-1] += self._tail[-1] - len(self)
        return value

    # PUBLIC API ...

    @property
    def level(self):
        """The current level of the skip list."""
        return self._level

    def insert(self, key, value):
        """Insert a key-value pair in the list.

        The pair is inserted at the correct location so that the list remains
        sorted on *key*. If a pair with the same key is already in the list,
        then the pair is appended after all other pairs with that key.
        """
        self._find_lte(key)
        node = self._create_node(key, value)
        self._insert(node)

    def replace(self, key, value):
        """Replace the value of the first key-value pair with key *key*.

        If the key was not found, the pair is inserted.
        """
        self._find_lt(key)
        node = self._path[0][2]
        if node is self._tail or self._is_key_lt(key, node[0]):
            node = self._create_node(key, value)
            self._insert(node)
        else:
            node[1] = value

    def clear(self):
        """Remove all key-value pairs."""
        for i in range(self.maxlevel):
            self._head[2+i] = self._tail
        # Every head link now points at the tail, so it skips no nodes.
        self._tail[-1] = 0
        self._level = 1

    def __len__(self):
        """Return the number of pairs in the list."""
        # Walk the top level only, summing each node's skip count; level-1
        # nodes carry no skip slot, so each of them counts as 1.
        dist = 0
        idx = self.level + 1
        node = self._head[idx]
        while node is not self._tail:
            dist += node[-1] if idx > 2 else 1
            node = node[idx]
        dist += node[-1]
        return dist

    __bool__ = __nonzero__ = lambda self: len(self) > 0

    def __repr__(self):
        return type(self).__name__ + '((' + repr(list(self.items()))[1:-1] + '))'

    def items(self, start=None, stop=None):
        """Return an iterator yielding pairs.

        If *start* is specified, iteration starts at the first pair whose key
        does not come before *start* in list order. If not specified,
        iteration starts at the first pair in the list.

        If *stop* is specified, iteration stops before the first pair whose
        key does not come before *stop* in list order. If not specified,
        iteration ends with the last pair in the list.

        "List order" is ascending keys, or descending keys when the list was
        created with ``reverse=True``.
        """
        if start is None:
            node = self._head[2]
        else:
            self._find_lt(start)
            node = self._path[0][2]
        # Use the list's own ordering for *stop*. A raw ``<`` is wrong for a
        # reversed list, where it would stop at the very first pair.
        while node is not self._tail and (stop is None or self._is_key_lt(node[0], stop)):
            yield (node[0], node[1])
            node = node[2]

    __iter__ = items

    def keys(self, start=None, stop=None):
        """Like :meth:`items` but returns only the keys."""
        return (item[0] for item in self.items(start, stop))

    def values(self, start=None, stop=None):
        """Like :meth:`items` but returns only the values."""
        return (item[1] for item in self.items(start, stop))

    def popitem(self):
        """Removes the first key-value pair and return it.

        This method raises a ``KeyError`` if the list is empty.
        """
        node = self._head[2]
        if node is self._tail:
            raise KeyError('list is empty')
        self._find_lt(node[0])
        self._remove(node)
        return (node[0], node[1])

    # BY KEY API ...

    def search(self, key, default=None):
        """Find the first key-value pair with key *key* and return its value.

        If the key was not found, return *default*. If no default was provided,
        return ``None``. This method never raises a ``KeyError``.
        """
        self._find_lt(key)
        node = self._path[0][2]
        if node is self._tail or self._is_key_lt(key, node[0]):
            return default
        return node[1]

    def remove(self, key):
        """Remove the first key-value pair with key *key*.

        If the key was not found, a ``KeyError`` is raised.
        """
        self._find_lt(key)
        node = self._path[0][2]
        if node is self._tail or self._is_key_lt(key, node[0]):
            raise KeyError('{!r} is not in list'.format(key))
        self._remove(node)

    def pop(self, key, default=UNSET):
        """Remove the first key-value pair with key *key*.

        If a pair was removed, return its value. Otherwise if *default* was
        provided, return *default*. Otherwise a ``KeyError`` is raised.
        """
        self._find_lt(key)
        node = self._path[0][2]
        if node is self._tail or self._is_key_lt(key, node[0]):
            if default is self.UNSET:
                raise KeyError('key {!r} not in list'.format(key))
            return default
        self._remove(node)
        return node[1]

    def __contains__(self, key):
        """Return whether *key* is contained in the list."""
        self._find_lt(key)
        node = self._path[0][2]
        return node is not self._tail and not self._is_key_lt(key, node[0])

    def index(self, pos):
        """Return the ``(key, value)`` pair at position *pos*.

        Negative positions count from the end, as for a list. Raises
        ``IndexError`` if *pos* is out of range.
        """
        size = len(self)
        if pos < 0:
            pos += size
        if not 0 <= pos < size:
            raise IndexError('list index out of range')
        self._find_pos(pos)
        node = self._path[0][2]
        return (node[0], node[1])

    def index_slice(self, start, stop):
        """Return a list of the ``(key, value)`` pairs at positions ``start:stop``.

        Follows list slicing rules: ``None`` means the beginning or the end,
        negative positions count from the end, and out-of-range positions
        are clamped to the list bounds.
        """
        size = len(self)
        if start is None:
            start = 0
        elif start < 0:
            # Clamp like list slicing; an unclamped negative start made the
            # walk below run from position 0 with a negative counter and
            # return too many pairs.
            start = max(start + size, 0)
        if stop is None:
            stop = size
        elif stop < 0:
            stop += size
        results = []
        if start >= stop:
            return results
        self._find_pos(start)
        pos = start
        node = self._path[0][2]
        while node is not self._tail and pos < stop:
            results.append((node[0], node[1]))
            node = node[2]
            pos += 1
        return results

    def count(self, key):
        """Return the number of pairs with key *key*."""
        # Position just before the first pair with *key*, then walk level 0
        # while the keys still compare equal in list order.
        self._find_lt(key)
        node = self._path[0][2]
        n = 0
        while node is not self._tail and not self._is_key_lt(key, node[0]):
            n += 1
            node = node[2]
        return n

    def __getitem__(self, key):
        """Return the value of the first pair with key *key*.

        Unlike a dict, a missing key returns ``None`` instead of raising
        ``KeyError`` (this delegates to :meth:`search`).
        """
        val = self.search(key)
        return val

    def __delitem__(self, key):
        """Remove the first pair with key *key*; ``KeyError`` if absent."""
        self.remove(key)

    def __setitem__(self, key, value):
        """Set the value of the first pair with key *key*, inserting if absent."""
        self.replace(key, value)
leftys commented 2 years ago

The author of pyskiplist actually writes on its GitHub repo that pyskiplist is deprecated in favor of SortedContainers, which is faster and has a smaller memory footprint.

Does your implementation store only 20 price levels? That could lead to an inconsistent order book after a few diff updates. It would also explain the better performance.

bmoscon commented 2 years ago

Closing as "won't do". If someone wants to open a PR with extensive tests and performance metrics, I might consider merging it, but I won't be doing this work.