PGPortfolio: Policy Gradient Portfolio, the source code of "A Deep Reinforcement Learning Framework for the Financial Portfolio Management Problem" (https://arxiv.org/pdf/1706.10059.pdf).
Different asset classes and time index

silakkamarkkina commented 6 years ago


Thank you for your great work and for the original paper, which inspired my choice of thesis topic as a master's student. I have been adapting the code to work with FX assets.

I have modified `globaldatamatrix.py`, especially the `get_global_panel` function, to match my daily OHLC FX data, which has no volume (volume would be hard to extrapolate accurately, even from several FX datasets that do include it).

However, I have run into a problem when the panel is created in `get_global_panel`: the `time_index` created at line 71 is far longer than my actual data, which produces an error at line 117: `panel.loc[feature, coin, serial_data.index] = serial_data.squeeze()`.

The error:

From what I understand, the problem is that `len(serial_data)` is about 700 observations shorter than `time_index`. I gathered my data with the following:

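```python
import numpy as np
import pandas as pd
import pandas_datareader.data as web


def preprocess(sym, data_source, startD, endD):
    try:
        df = web.DataReader(sym, data_source, startD, endD)[['Open', 'High', 'Low', 'Close']]
        # DatetimeIndex -> Unix timestamps in seconds
        df.index = df.index.astype(np.int64) // 10**9
        if sym[3] != '=':
            # Invert the quote. Inversion swaps the extremes:
            # the new high is 1/low and the new low is 1/high.
            df[['Open', 'High', 'Low', 'Close']] = 1 / df[['Open', 'Low', 'High', 'Close']].values
        # The first three characters of the symbol become the "coin" name
        df['Coin'] = sym[0:3]
        df = df[['Coin', 'Open', 'High', 'Low', 'Close']]
        outN = './Thesis/Data/FX/' + sym[0:3] + '.csv'
        df.to_csv(outN, sep=',', decimal='.')
    except Exception as e:
        print('Error for: %s (%s)' % (sym, e))


# symbols, data_source, startD and endD are defined elsewhere in the script
for sym in symbols:
    preprocess(sym, data_source, startD, endD)
```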
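To see where the roughly 700 missing observations come from, the timestamps in each CSV can be compared with the calendar-day grid that `get_global_panel` builds as `time_index`. The helper `missing_dates` below is a hypothetical sketch (not part of PGPortfolio) and assumes the CSV index holds Unix seconds, as written by `preprocess`. Spot FX does not trade on weekends, so weekend dates are a natural first thing to look for in its output.

```python
import pandas as pd

DAY = 86400  # seconds in one daily period


def missing_dates(timestamps, start=None, end=None, period=DAY):
    """Return the grid points from start to end that have no data row.

    `timestamps` are Unix seconds (e.g. the index of one CSV written by
    preprocess()). start/end default to the first and last timestamp; like
    time_index in get_global_panel, both are floored to a multiple of
    `period` and the grid is stepped by `period`.
    """
    ts = pd.Index(timestamps, dtype="int64")
    start = int(ts.min() if start is None else start)
    end = int(ts.max() if end is None else end)
    start -= start % period
    end -= end % period
    grid = pd.Index(range(start, end + 1, period), dtype="int64")
    return pd.to_datetime(grid.difference(ts), unit="s")


# Usage: two weeks of weekday-only daily bars starting Monday 2018-01-01 UTC
monday = 1514764800
weekdays = [monday + d * DAY for d in range(12) if d % 7 < 5]
print(missing_dates(weekdays))  # 2018-01-06 and 2018-01-07 (Saturday and Sunday)
```

For a real file, `missing_dates(pd.read_csv(path, index_col=0).index, start, end)` with the requested `start` and `end` would also count dates missing before the first or after the last row.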

I then merge the CSVs and import them into the .db file with DB Browser, so that they have the same form as the original cryptocurrency database. Below is my edited version of `globaldatamatrix.py`. Do you have any recommendations on how to proceed? Thanks in advance; I am relatively new to programming.

```python
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

from pgportfolio.marketdata.coinlist import CoinList
import numpy as np
import pandas as pd
from pgportfolio.tools.data import panel_fillna
from pgportfolio.constants import *
import sqlite3
from datetime import datetime
import logging


class HistoryManager:
    # if offline, the coin_list could be None
    # NOTE: return of the sqlite results is a list of tuples, each tuple is a row
    def __init__(self, coin_number, end, volume_average_days=1, volume_forward=0, online=True):
        self.__storage_period = FIVE_MINUTES  # keep this as 300
        self._coin_number = coin_number
        self._online = online
        if self._online:
            self._coin_list = CoinList(end, volume_average_days, volume_forward)
        self.__volume_forward = volume_forward
        self.__volume_average_days = volume_average_days
        self.__coins = None

    def coins(self):
        return self.__coins

    def initialize_db(self):
        with sqlite3.connect(DATABASE_DIR) as connection:
            cursor = connection.cursor()
            cursor.execute('CREATE TABLE IF NOT EXISTS History (date INTEGER,'
                           ' coin TEXT, high REAL, low REAL,'
                           ' open REAL, close REAL, volume REAL, '
                           ' quoteVolume REAL, weightedAverage REAL,'
                           'PRIMARY KEY (date, coin));')

    def get_global_data_matrix(self, start, end, period=86400, features=('close',)):
        """
        :return a numpy ndarray whose axis is [feature, coin, time]
        """
        return self.get_global_panel(start, end, period, features).values

    def get_global_panel(self, start, end, period=86400, features=('close',)):
        """
        :param start/end: linux timestamp in seconds
        :param period: time interval of each data access point
        :param features: tuple or list of the feature names
        :return a panel, [feature, coin, time]
        """
        start = int(start - (start % period))
        end = int(end - (end % period))
        coins = self.select_coins(start=start, end=...)
        self.__coins = coins
        for coin in coins:
            self.update_data(start, end, coin)

        if len(coins) != self._coin_number:
            raise ValueError("the length of selected coins %d is not equal to expected %d"
                             % (len(coins), self._coin_number))

        logging.info("feature type list is %s" % str(features))

        # one slot per period from start to end, whether or not data exists for it
        time_index = pd.to_datetime(list(range(start, end + 1, period)), unit='s')
        # NOTE: pd.Panel was removed in pandas 1.0, so this requires pandas < 1.0
        panel = pd.Panel(items=features, major_axis=coins, minor_axis=time_index, dtype=np.float32)

        connection = sqlite3.connect(DATABASE_DIR)
        try:
            for row_number, coin in enumerate(coins):
                for feature in features:
                    # NOTE: transform the start date to end date
                    if feature == "close":
                        sql = ("SELECT date+{period} AS date_norm, close FROM History WHERE"
                               " date_norm>={start} and date_norm<={end}"
                               " and date_norm%{period}=0 and coin=\"{coin}\"".format(
                                   start=start, end=end, period=period, coin=coin))
                    elif feature == "open":
                        sql = ("SELECT date+{period} AS date_norm, open FROM History WHERE"
                               " date_norm>={start} and date_norm<={end}"
                               " and date_norm%{period}=0 and coin=\"{coin}\"".format(
                                   start=start, end=end, period=period, coin=coin))
                    elif feature == "volume":
                        sql = ("SELECT date_norm, SUM(volume)" +
                               " FROM (SELECT date+{period}-(date%{period}) "
                               "AS date_norm, volume, coin FROM History)"
                               " WHERE date_norm>={start} and date_norm<={end} and coin=\"{coin}\""
                               " GROUP BY date_norm".format(
                                   start=start, end=end, period=period, coin=coin))
                    elif feature == "high":
                        sql = ("SELECT date_norm, MAX(high)" +
                               " FROM (SELECT date+{period}-(date%{period})"
                               " AS date_norm, high, coin FROM History)"
                               " WHERE date_norm>={start} and date_norm<={end} and coin=\"{coin}\""
                               " GROUP BY date_norm".format(
                                   start=start, end=end, period=period, coin=coin))
                    elif feature == "low":
                        sql = ("SELECT date_norm, MIN(low)" +
                               " FROM (SELECT date+{period}-(date%{period})"
                               " AS date_norm, low, coin FROM History)"
                               " WHERE date_norm>={start} and date_norm<={end} and coin=\"{coin}\""
                               " GROUP BY date_norm".format(
                                   start=start, end=end, period=period, coin=coin))
                    else:
                        msg = ("The feature %s is not supported" % feature)
                        raise ValueError(msg)
                    # parse_dates/index_col turn date_norm (Unix seconds) into a
                    # DatetimeIndex that can be matched against time_index
                    serial_data = pd.read_sql_query(sql, con=connection,
                                                    parse_dates=["date_norm"],
                                                    index_col="date_norm")
                    panel.loc[feature, coin, serial_data.index] = serial_data.squeeze()
                    panel = panel_fillna(panel, "both")
        finally:
            connection.close()
        return panel

    # online: top coin_number coins by volume; offline: every coin with data from start to end
    def select_coins(self, start, end):
        if not self._online:
            logging.info("select coins offline from %s to %s" % (datetime.fromtimestamp(start).strftime('%Y-%m-%d %H:%M'),
                                                                    datetime.fromtimestamp(end).strftime('%Y-%m-%d %H:%M')))
            connection = sqlite3.connect(DATABASE_DIR)
            try:
                cursor = connection.cursor()
                cursor.execute('SELECT coin FROM History WHERE'
                               ' date>=? and date<=? GROUP BY coin',
                               (int(start), int(end)))
                coins_tuples = cursor.fetchall()

                if len(coins_tuples) != self._coin_number:
                    logging.error("the sqlite error happened")
            finally:
                connection.close()
            coins = [row[0] for row in coins_tuples]
        else:
            coins = list(self._coin_list.topNVolume(n=self._coin_number).index)
        logging.debug("Selected coins are: " + str(coins))
        return coins

    def __checkperiod(self, period):
        if period not in (FIVE_MINUTES, FIFTEEN_MINUTES, HALF_HOUR, TWO_HOUR, FOUR_HOUR, DAY):
            raise ValueError('period has to be 5min, 15min, 30min, 2hr, 4hr, or a day')

    # add new history data into the database
    def update_data(self, start, end, coin):
        connection = sqlite3.connect(DATABASE_DIR)
        try:
            cursor = connection.cursor()
            min_date = cursor.execute('SELECT MIN(date) FROM History WHERE coin=?;', (coin,)).fetchall()[0][0]
            max_date = cursor.execute('SELECT MAX(date) FROM History WHERE coin=?;', (coin,)).fetchall()[0][0]

            if min_date is None or max_date is None:
                # if there is no data
                self.__fill_data(start, end, coin, cursor)
            else:
                if max_date + 10 * self.__storage_period < end:
                    if not self._online:
                        raise Exception("Have to be online")
                    self.__fill_data(max_date + self.__storage_period, end, coin, cursor)
                if min_date > start and self._online:
                    self.__fill_data(start, min_date - self.__storage_period - 1, coin, cursor)
        finally:
            connection.commit()
            connection.close()

    def __fill_data(self, start, end, coin, cursor):
        chart = self._coin_list.get_chart_until_success(
            pair=self._coin_list.allActiveCoins.at[coin, 'pair'],
            start=start,
            end=end,
            period=self.__storage_period)
        logging.info("fill %s data from %s to %s" % (coin, datetime.fromtimestamp(start).strftime('%Y-%m-%d %H:%M'),
                                                   datetime.fromtimestamp(end).strftime('%Y-%m-%d %H:%M')))
        for c in chart:
            if c["date"] > 0:
                if c['weightedAverage'] == 0:
                    weightedAverage = c['close']
                else:
                    weightedAverage = c['weightedAverage']

                # NOTE here the USDT is in reversed order
                if 'reversed_' in coin:
                    cursor.execute('INSERT INTO History VALUES (?,?,?,?,?,?,?,?,?)', ...)
                else:
                    cursor.execute('INSERT INTO History VALUES (?,?,?,?,?,?,?,?,?)', ...)
```
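Instead of merging the CSVs and importing them with DB Browser, the import step could also be scripted. The sketch below assumes the CSV layout produced by `preprocess` and fills a `History` table with the schema from `initialize_db`. The function name `load_fx_csvs`, storing `volume`/`quoteVolume` as NULL, and using the close price as `weightedAverage` are assumptions for illustration, not something PGPortfolio prescribes.

```python
import glob
import os
import sqlite3

import pandas as pd

# Same table layout as HistoryManager.initialize_db()
CREATE_HISTORY = (
    "CREATE TABLE IF NOT EXISTS History (date INTEGER, coin TEXT, "
    "high REAL, low REAL, open REAL, close REAL, volume REAL, "
    "quoteVolume REAL, weightedAverage REAL, PRIMARY KEY (date, coin));"
)


def load_fx_csvs(csv_dir, db_path):
    """Insert every CSV in csv_dir into the History table; return the row count.

    Assumed CSV layout: the first column is the Unix-second date, followed by
    Coin, Open, High, Low, Close. FX data has no volume, so volume and
    quoteVolume are stored as NULL and weightedAverage is set to the close
    price (both are assumptions, not PGPortfolio requirements).
    """
    paths = sorted(glob.glob(os.path.join(csv_dir, "*.csv")))
    df = pd.concat([pd.read_csv(path, index_col=0) for path in paths])
    # Column order must match the table: date, coin, high, low, open, close,
    # volume, quoteVolume, weightedAverage
    rows = [
        (int(date), str(coin), float(high), float(low), float(opn), float(close),
         None, None, float(close))
        for date, coin, opn, high, low, close in zip(
            df.index, df["Coin"], df["Open"], df["High"], df["Low"], df["Close"])
    ]
    connection = sqlite3.connect(db_path)
    try:
        with connection:  # commits on success, rolls back on error
            connection.execute(CREATE_HISTORY)
            # INSERT OR REPLACE overwrites rows with the same (date, coin) key
            connection.executemany(
                "INSERT OR REPLACE INTO History VALUES (?,?,?,?,?,?,?,?,?)", rows)
    finally:
        connection.close()
    return len(rows)
```

Because `(date, coin)` is the table's primary key, running the import again on the same files replaces the existing rows instead of duplicating them.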
silakkamarkkina commented 6 years ago

The problem was caused by NaNs in the data I downloaded myself; it is not an issue with the code here.
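Since the cause turned out to be NaNs in the downloaded data, a check like the one below could run inside `preprocess` just before `to_csv`, so that rows with missing prices are reported and kept out of the database. `drop_incomplete_rows` is a hypothetical helper, and dropping such rows is only one option (forward-filling them would be another). Dropped dates then simply have no row in `History`, and the panel slots that `get_global_panel` never assigns are what its `panel_fillna` call is there to fill.

```python
import pandas as pd

OHLC = ["Open", "High", "Low", "Close"]


def drop_incomplete_rows(df, columns=OHLC):
    """Split an OHLC frame into rows with every price present and the dates that had NaNs."""
    has_nan = df[columns].isna().any(axis=1)
    return df.loc[~has_nan], df.index[has_nan]


# Usage: the middle row has a missing Open price
frame = pd.DataFrame(
    {"Coin": ["EUR"] * 3,
     "Open": [1.0, float("nan"), 1.2],
     "High": [1.1, 1.2, 1.3],
     "Low": [0.9, 1.0, 1.1],
     "Close": [1.0, 1.1, 1.2]},
    index=[1514764800, 1514851200, 1514937600])
clean, bad_dates = drop_incomplete_rows(frame)
if len(bad_dates):
    print("dropped %d rows with NaN prices: %s" % (len(bad_dates), bad_dates.tolist()))
```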