quantopian / pyfolio

Portfolio and risk analytics in Python
https://quantopian.github.io/pyfolio
Apache License 2.0
5.7k stars 1.78k forks source link

error in extract_round_trips #602

Open wukan1986 opened 5 years ago

wukan1986 commented 5 years ago

Problem Description

Please provide a minimal, self-contained, and reproducible example:

import pandas as pd
import pyfolio as pf
from io import StringIO

# Minimal reproducer: four single-unit fills in one symbol
# (sell 1, buy 1, buy 1, sell 1), parsed with the timestamp column as index.
content = """
index,symbol,amount,price
2019-03-21 09:00:00,ABCD,-1,100
2019-03-21 09:10:00,ABCD,1,200
2019-03-21 09:20:00,ABCD,1,200
2019-03-21 09:30:00,ABCD,-1,100
"""
transactions = pd.read_csv(
    StringIO(content),
    index_col=0,
    parse_dates=True,
)
# Feed the fills to pyfolio's round-trip extraction and show what it returns.
round_trips = pf.round_trips.extract_round_trips(transactions)
print(round_trips)

"""
             close_dt   long             open_dt   pnl  rt_returns symbol duration
0 2019-03-21 09:10:00  False 2019-03-21 09:00:00 -15.0   -0.150000   ABCD 00:10:00
1 2019-03-21 09:30:00   True 2019-03-21 09:10:00  15.0    0.130435   ABCD 00:20:00

"""

Please provide the full traceback:

Please provide any additional information below:

In row 1, `open_dt` should be 2019-03-21 09:20:00, not 2019-03-21 09:10:00.

Versions

wukan1986 commented 5 years ago

My new code:

import pandas as pd
from io import StringIO
from math import copysign
from collections import deque
import numpy as np
import warnings
import itertools

pd.set_option('display.width', 1000)

def extract_round_trips(transactions):
    """Pair opening and closing transactions into FIFO round trips.

    Transactions are processed per symbol in index order.  Each transaction
    that opens or extends a position becomes one "lot" with its own ticket
    number (``oid``).  A transaction in the opposite direction closes the
    oldest open lots first; one round trip is emitted for every
    (closing transaction, lot) pair it touches.  If a closing transaction is
    larger than the open position, the excess opens a single new lot in the
    opposite direction.

    Parameters
    ----------
    transactions : pd.DataFrame
        One row per fill.  The index is used as the fill time and the frame
        must contain the columns ``symbol``, ``amount`` (signed quantity,
        truncated to whole units) and ``price``.

    Returns
    -------
    pd.DataFrame
        One row per round trip with the columns ``pnl``, ``open_dt``,
        ``close_dt``, ``long``, ``rt_returns``, ``symbol``, ``oid``,
        ``open_price``, ``close_price``, ``size`` and ``duration``
        (``close_dt - open_dt``).  An empty frame with these columns is
        returned when no round trip was completed.

    Notes
    -----
    Rows with a negative price are skipped with a warning, and rows whose
    amount truncates to zero are ignored.
    """
    columns = ['pnl', 'open_dt', 'close_dt', 'long', 'rt_returns', 'symbol',
               'oid', 'open_price', 'close_price', 'size']
    roundtrips = []
    # Ticket numbers are shared across symbols so every lot gets a unique id.
    ticket = itertools.count(1)

    for sym, trans_sym in transactions.groupby('symbol'):
        # Stable sort: fills that share a timestamp keep their input order.
        trans_sym = trans_sym.sort_index(kind='mergesort')

        # FIFO queue of open lots, oldest first.  Each lot is
        # [direction, open_price, remaining_size, open_dt, oid]; all lots in
        # the queue always share the same direction.
        open_lots = deque()

        for dt, amount, price in zip(trans_sym.index,
                                     trans_sym['amount'],
                                     trans_sym['price']):
            if price < 0:
                warnings.warn('Negative price detected, ignoring for '
                              'round-trip.')
                continue

            remaining = abs(int(amount))
            if remaining == 0:
                # Nothing traded; an empty lot would corrupt the queue.
                continue
            # Direction comes from the amount, not from a signed price, so
            # that fills at price 0 are still classified correctly.
            direction = 1 if amount > 0 else -1

            # Close the oldest opposite-direction lots first.
            while remaining and open_lots and open_lots[0][0] != direction:
                lot = open_lots[0]
                lot_direction, open_price, lot_size, open_dt, oid = lot
                matched = min(remaining, lot_size)

                pnl = lot_direction * (price - open_price) * matched
                invested = open_price * matched
                roundtrips.append({
                    'pnl': pnl,
                    'open_dt': open_dt,
                    'close_dt': dt,
                    'long': lot_direction > 0,
                    # A lot opened at price 0 has no capital at risk.
                    'rt_returns': (pnl / invested if invested
                                   else float('nan')),
                    'symbol': sym,
                    'oid': oid,
                    'open_price': open_price,
                    'close_price': price,
                    'size': matched,
                })

                remaining -= matched
                if matched == lot_size:
                    open_lots.popleft()
                else:
                    lot[2] = lot_size - matched

            if remaining:
                # Either a plain opening fill or the overshoot of a closing
                # fill: both become a single new lot with one ticket.
                open_lots.append(
                    [direction, price, remaining, dt, next(ticket)])

    if not roundtrips:
        return pd.DataFrame(columns=columns + ['duration'])

    roundtrips = pd.DataFrame(roundtrips, columns=columns)
    roundtrips['duration'] = roundtrips['close_dt'].sub(roundtrips['open_dt'])

    return roundtrips

if __name__ == '__main__':
    # Demo input: two buys, two sells and a final buy in a single symbol.
    content = '''
    index,symbol,amount,price
    2019-03-21 09:00:00,ABCD,2,100
    2019-03-21 09:10:00,ABCD,2,110
    2019-03-21 09:20:00,ABCD,-3,120
    2019-03-21 09:30:00,ABCD,-2,130
    2019-03-21 09:40:00,ABCD,1,140
    '''

    # Output the author reports for the input above (kept as a bare string).
    '''
             close_dt  close_price   long  oid             open_dt  open_price  pnl  rt_returns  size symbol duration
0 2019-03-21 09:20:00          120   True    1 2019-03-21 09:00:00         100   40    0.200000     2   ABCD 00:20:00
1 2019-03-21 09:20:00          120   True    2 2019-03-21 09:10:00         110   10    0.090909     1   ABCD 00:10:00
2 2019-03-21 09:30:00          130   True    2 2019-03-21 09:10:00         110   20    0.181818     1   ABCD 00:20:00
3 2019-03-21 09:40:00          140  False    3 2019-03-21 09:30:00         130  -10   -0.076923     1   ABCD 00:10:00
    '''

    # Parse the inline CSV, using the first column as a datetime index.
    csv_buffer = StringIO(content)
    trades = pd.read_csv(csv_buffer, index_col=0, parse_dates=True)

    round_trips = extract_round_trips(trades)

    # Show the raw fills first, then the round trips derived from them.
    print(trades)
    print(round_trips)