bionicles opened this issue 3 years ago
Second idea: how could we simulate adding money to the account at regular intervals (recurring investment)? I just want to model my own portfolio as closely as possible; I'm buying a little bit every day and reinvesting dividends.
No worries if it's a ton of work to add this stuff, but it could be fun to specify daily, weekly, or monthly contributions, with dividend reinvestment, combined with algos doing their thing too, and see how it all adds up.
Hi @bionicles, I will add dividends and recurring investments in the next release. Dividend reinvestment you would need to handle manually because it may conflict with your own orders (there is a cap of one order per timestamp in most methods). Since vectorbt is mainly targeted at crypto, I will also add staking rewards.
Dear @polakowo,
I'm still having fun tinkering with the library; are you interested in more ideas?
One issue I'm having is that Numba is pretty finicky about invoking lists of functions passed as arguments, because they're "heterogeneous lists", plus a billion other little restrictions... Have you thought about trying a different JIT backend for vectorbt?
I love JAX from Google: it has JIT, autograd (meaning we could do differentiable algos and backprop onto the parameters, or some crazy stuff like that), a full NumPy rebuild that compiles to XLA and runs on the GPU if desired, a reproducible random number generator, and tree_util, which often saves many lines of code (tree_multimap). In my experience JAX is more often OK with higher-order programming, whereas Numba crashes.
You can JIT more code with JAX and integrate differentiable programming. It seems like a fascinating opportunity to integrate AI with vectorbt, imho. For neural net libraries in JAX there are https://github.com/google/flax, https://github.com/deepmind/dm-haiku, and a bunch more.
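Just to illustrate the jit + grad combo I mean, here's a tiny self-contained sketch (hypothetical toy code, nothing to do with vectorbt's internals): compile a toy P&L function and take its gradient with respect to a strategy parameter.
import jax
import jax.numpy as jnp

def pnl(threshold, prices):
    # differentiable stand-in for entry signals: soft position sizing around a price threshold
    signal = jax.nn.sigmoid(prices - threshold)
    returns = jnp.diff(prices) / prices[:-1]
    return jnp.sum(signal[:-1] * returns)

grad_pnl = jax.jit(jax.grad(pnl))  # d(P&L)/d(threshold), compiled with XLA
prices = jnp.array([100.0, 101.0, 99.5, 102.0])
print(grad_pnl(100.5, prices))     # gradient you could feed to any optimizer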
It's your project, obviously, but I just think JAX is a worthwhile alternative "compute accelerator backend", one of my favorite libraries, and it could be an upgrade. I'd be down to tinker with it and share code, but I'm not familiar with the internals of vectorbt.
What do you think?
https://jax.readthedocs.io/en/latest/
https://moocaholic.medium.com/jax-a13e83f49897
Hi @bionicles, thanks for the idea, I will give it a try to see how it compares to Numba.
Do you think jax.tree_util.tree_multimap could enable further broadcasting in the portfolio generation steps?
I wonder because you could potentially pass in two PyTrees (arbitrarily nested collections) with a matching number of leaves for prices and targets in TargetPercent mode. Then you could build multiple differently-shaped portfolios.
Edit: oops, tree_map, not multimap: https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map.html#jax.tree_util.tree_map
# hypothetical usage, if PyTrees were accepted:
prices = (qqq, eth_usd)                   # different number of candles per array
targets = (qqq_targets, eth_usd_targets)  # target percentages matching each price array
portfolio = vbt.Portfolio.from_orders(
    prices,
    targets,
    size_type=vbt.portfolio.enums.SizeType.TargetPercent,
    init_cash=25000,
    slippage=0.0006,
)
And it could just work: it would broadcast each pair of prices and targets, and it would work on tuples, lists, and dicts, returning the same container shape you pass in. So if you pass in prices and targets as matching-keyed dictionaries, you could get out a matching-keyed dictionary of portfolios with the same keys.
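Here's a minimal sketch of the tree_map mechanics I'm imagining; build_portfolio is just a hypothetical stand-in for whatever vectorbt would do per leaf, not the real from_orders:
import jax
import jax.numpy as jnp

def build_portfolio(price, target):
    # hypothetical per-pair builder; in reality this would call into vectorbt
    return {"price": price, "target": target}

# matching-keyed dicts; array leaves can have different lengths per key
prices = {"qqq": jnp.array([1.0, 2.0, 3.0]), "eth_usd": jnp.array([4.0, 5.0])}
targets = {"qqq": jnp.array([0.5, 0.5, 0.5]), "eth_usd": jnp.array([1.0, 1.0])}

# tree_map pairs up the leaves at the same key and returns the same dict structure
portfolios = jax.tree_util.tree_map(build_portfolio, prices, targets)
print(sorted(portfolios))  # ['eth_usd', 'qqq'] -- same keys out as went in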
Another big benefit of using JAX instead of Numba: JAX already supports Python 3.10, so we could get the advantages of newer Python versions sooner. For example, the structural pattern matching in 3.10 looks awesome: https://docs.python.org/3/whatsnew/3.10.html. If there were a JAX backend for vectorbt, it could run on Python 3.10 today; I'm not sure how long it will take to port Numba to 3.10.
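Just to show what that 3.10 syntax looks like (a made-up order-handling toy, not vectorbt code):
def describe_order(order):
    # `match` destructures the tuple and dispatches on its shape and values (Python 3.10+)
    match order:
        case ("buy", size, price) if size > 0:
            return f"buy {size} @ {price}"
        case ("sell", size, price) if size > 0:
            return f"sell {size} @ {price}"
        case _:
            return "ignored"

print(describe_order(("buy", 10, 101.5)))  # buy 10 @ 101.5
print(describe_order(("sell", 0, 99.0)))   # ignored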
Hi @bionicles, I have little to no time to do my own research on JAX as I'm dealing with more urgent features. If you have the time and curiosity, you can port some of vectorbt's functions to JAX (such as any from vectorbt.generic.nb); I'd be eager to see the benchmarks!
No rush or expectation, I'm down to tinker. Is there any place in the codebase where we could put a toggle to switch between JAX and Numba? Then I could just flip it into JAX mode and debug until it works.
For example (a rough sketch; env here stands for some hypothetical global settings object):
# vectorbt/generic/accelerator.py
import jax
import numba
jit = jax.jit if env.ACCELERATOR == "jax" else numba.njit

# vectorbt/generic/fun.py -- numeric modules import the toggled decorator
from .accelerator import jit
Also, is it possible to pass custom subplots? (I'm using the TargetPercent size type and want to plot the target percentages with a shared time axis underneath the cumulative returns.)
There isn't such a place as of now, and you can't simply swap the engine since many functions are tailored to Numba, such as those decorated with generated_jit. There will be a global decorator for jitting functions in the next release though; that's the closest you can get to an engine-agnostic design. I'm curious how JAX compares to Numba; the examples I saw on the internet were mostly about vectorized ops with NumPy, but what I'm looking for are essentially loops.
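Roughly speaking, such a decorator could take the following shape (a hypothetical sketch only, not the actual implementation):
import numba
try:
    import jax
except ImportError:
    jax = None

ENGINE = "numba"  # imagine this coming from a global settings object

def register_jitted(**options):
    # pick the compilation backend once, at decoration time
    def wrapper(py_func):
        if ENGINE == "jax" and jax is not None:
            return jax.jit(py_func)
        return numba.njit(**options)(py_func)
    return wrapper

@register_jitted(cache=True)
def diff_1d(a):
    return a[1:] - a[:-1]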
It's all in the docs: https://vectorbt.dev/docs/portfolio/base.html#plots
Here are some benchmarks on generic tasks:
https://github.com/dionhaefner/pyhpc-benchmarks
Looks like it's about the same or a bit slower on CPU, but it enables GPU mode with a potential 100-1000x speedup.
Here's the API for loops, maps, scans, etc. in JAX:
https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators
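For example, a sequential loop like the running returns calculation maps onto jax.lax.scan like this (a quick sketch, not ported vectorbt code):
import jax
import jax.numpy as jnp

def returns_scan(value, init_value):
    # carry = previous value; y = return at the current step
    def step(prev, cur):
        return cur, (cur - prev) / prev
    init = jnp.asarray(init_value, dtype=value.dtype)  # keep the carry dtype stable
    _, rets = jax.lax.scan(step, init, value)
    return rets

print(returns_scan(jnp.array([110.0, 121.0]), 100.0))  # [0.1 0.1]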
Which generic.nb function is the most useful for me to port and benchmark first?
Looks promising, especially considering that a lot of vectorbt's code can run in parallel. On the development branch, I already enabled most code to run in parallel using Numba's parallel flag, but it results in a smaller speedup than chunking and running each chunk with Dask, so there is some room for improvement.
You can try to port any function that takes another function as an argument. Then this signal generation function, which runs two user-defined functions one after another to place entry and exit signals in-place. It would be interesting to see whether JAX allows modifying arrays. Finally, the generation of record arrays here. We also need to test dynamic creation and use of named tuples. If all of this works in JAX, then there should be no issues porting the simulation functions as well.
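For reference, JAX arrays are immutable, so the in-place writes would have to become functional updates via .at[...].set(...), which return a new array (under jit, XLA can often fuse them back into real in-place buffer updates). A minimal sketch:
import jax
import jax.numpy as jnp

@jax.jit
def place_entry(signals, i):
    # no `signals[i] = True` in JAX; this returns an updated copy instead
    return signals.at[i].set(True)

signals = jnp.zeros(5, dtype=bool)
print(place_entry(signals, 2))  # [False False  True False False]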
OK, here's a benchmark of returns_jax vs returns_nb on an Ubuntu 20.04 rig running the latest Numba, NumPy, and JAX, with a 12-core Intel CPU, 64 GB RAM, and an NVIDIA GTX 1070 Ti 11 GB (you can get much more powerful graphics cards than this, not to mention the TPU backend).
The timings for JAX are suspiciously flat, but I am using block_until_ready(), so it definitely returns a result, and I also print out the error between the JAX result and the Numba result... not sure if I messed something up, but this goes to show the power of JAX.
I'm using jax.vmap to make a 2-D returns function, which just vector-maps the 1-D returns function over the columns of the 2-D value array.
calculate (2048, 4) returns (size 8192)
jax:
10.6 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
nb:
184 µs ± 3.29 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
mean absolute error between jax and numba returns:
jnp.abs(nb_test - jax_test).mean() =
[DeviceArray(0.00156941, dtype=float32), DeviceArray(0.00181136, dtype=float32), DeviceArray(0.00315392, dtype=float32), DeviceArray(0.00239617, dtype=float32), DeviceArray(0.00507746, dtype=float32)]
calculate (2048, 8) returns (size 16384)
jax:
10.7 ms ± 223 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
nb:
324 µs ± 3.67 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
mean absolute error between jax and numba returns:
jnp.abs(nb_test - jax_test).mean() =
[DeviceArray(0.00266934, dtype=float32), DeviceArray(0.00295464, dtype=float32), DeviceArray(0.00394731, dtype=float32), DeviceArray(0.0022995, dtype=float32), DeviceArray(0.01856399, dtype=float32)]
calculate (2048, 16) returns (size 32768)
jax:
10.8 ms ± 187 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
nb:
643 µs ± 5.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
mean absolute error between jax and numba returns:
jnp.abs(nb_test - jax_test).mean() =
[DeviceArray(0.00224286, dtype=float32), DeviceArray(0.00581381, dtype=float32), DeviceArray(0.00162147, dtype=float32), DeviceArray(0.0025181, dtype=float32), DeviceArray(0.00343394, dtype=float32)]
calculate (2048, 32) returns (size 65536)
jax:
10.7 ms ± 128 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
nb:
1.41 ms ± 11 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
mean absolute error between jax and numba returns:
jnp.abs(nb_test - jax_test).mean() =
[DeviceArray(0.00663906, dtype=float32), DeviceArray(0.00350164, dtype=float32), DeviceArray(0.0018475, dtype=float32), DeviceArray(0.00202276, dtype=float32), DeviceArray(0.00378822, dtype=float32)]
calculate (2048, 64) returns (size 131072)
jax:
10.6 ms ± 124 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
nb:
3.06 ms ± 19 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
mean absolute error between jax and numba returns:
jnp.abs(nb_test - jax_test).mean() =
[DeviceArray(0.00477037, dtype=float32), DeviceArray(0.00977544, dtype=float32), DeviceArray(0.00477652, dtype=float32), DeviceArray(0.00462031, dtype=float32), DeviceArray(0.01092727, dtype=float32)]
calculate (2048, 128) returns (size 262144)
jax:
10.6 ms ± 149 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
nb:
5.51 ms ± 27.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
mean absolute error between jax and numba returns:
jnp.abs(nb_test - jax_test).mean() =
[DeviceArray(0.01863427, dtype=float32), DeviceArray(0.0039546, dtype=float32), DeviceArray(0.0088834, dtype=float32), DeviceArray(0.01100635, dtype=float32), DeviceArray(0.0041334, dtype=float32)]
calculate (2048, 256) returns (size 524288)
jax:
10.7 ms ± 116 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
nb:
11.2 ms ± 72.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
mean absolute error between jax and numba returns:
jnp.abs(nb_test - jax_test).mean() =
[DeviceArray(0.00524056, dtype=float32), DeviceArray(0.00535696, dtype=float32), DeviceArray(0.0193589, dtype=float32), DeviceArray(0.0171509, dtype=float32), DeviceArray(0.0048359, dtype=float32)]
calculate (2048, 512) returns (size 1048576)
jax:
10.7 ms ± 92.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
nb:
36.1 ms ± 155 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
mean absolute error between jax and numba returns:
jnp.abs(nb_test - jax_test).mean() =
[DeviceArray(0.00607103, dtype=float32), DeviceArray(0.00557117, dtype=float32), DeviceArray(0.00548947, dtype=float32), DeviceArray(0.010967, dtype=float32), DeviceArray(0.01028708, dtype=float32)]
calculate (2048, 1024) returns (size 2097152)
jax:
11.2 ms ± 192 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
nb:
91.5 ms ± 313 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
mean absolute error between jax and numba returns:
jnp.abs(nb_test - jax_test).mean() =
[DeviceArray(0.0061493, dtype=float32), DeviceArray(0.00508331, dtype=float32), DeviceArray(0.00702058, dtype=float32), DeviceArray(0.00534153, dtype=float32), DeviceArray(0.00589914, dtype=float32)]
calculate (2048, 2048) returns (size 4194304)
jax:
12.2 ms ± 88.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
nb:
170 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
mean absolute error between jax and numba returns:
jnp.abs(nb_test - jax_test).mean() =
[DeviceArray(0.00593122, dtype=float32), DeviceArray(0.0050958, dtype=float32), DeviceArray(0.00858048, dtype=float32), DeviceArray(0.00502566, dtype=float32), DeviceArray(0.01482371, dtype=float32)]
I'm using Jupyter, which apparently doesn't render here; one code block per notebook cell:
from jax import devices, random, jit, vmap, numpy as jnp
from vectorbt import _typing as tp
from numba import njit
import numpy as np
import itertools
from functools import partial
import pandas as pd
import random as pyrandom

print(devices(backend="gpu"))
gpu = devices(backend="gpu")[0]
def _returns_1d_jax(value: tp.Array1d, init_value: float):
    # prepend the initial value, then divide each difference by the *previous* value
    Y = jnp.concatenate((jnp.array([init_value]), value))
    return jnp.divide(jnp.diff(Y), Y[:-1]) * jnp.sign(Y[:-1])

returns_1d_jax = jit(_returns_1d_jax, device=gpu)
# vmap over columns: turns the 1-D returns function into a 2-D (per-column) one
returns_jax = lambda v, i: vmap(partial(returns_1d_jax, init_value=i), in_axes=1, out_axes=1)(v)
@njit(cache=True)
def get_return_nb(input_value: float, output_value: float) -> float:
    """Calculate return from input and output value."""
    if input_value == 0:
        if output_value == 0:
            return 0.
        return np.inf * np.sign(output_value)
    return_value = (output_value - input_value) / input_value
    if input_value < 0:
        return_value *= -1
    return return_value


@njit(cache=True)
def returns_1d_nb(value: tp.Array1d, init_value: float) -> tp.Array1d:
    """Calculate returns from value."""
    out = np.empty(value.shape, dtype=np.float_)
    input_value = init_value
    for i in range(out.shape[0]):
        output_value = value[i]
        out[i] = get_return_nb(input_value, output_value)
        input_value = output_value
    return out


@njit(cache=True)
def returns_nb(value: tp.Array2d, init_value: tp.Array1d) -> tp.Array2d:
    """2-dim version of `returns_1d_nb`."""
    out = np.empty(value.shape, dtype=np.float_)
    for col in range(out.shape[1]):
        out[:, col] = returns_1d_nb(value[:, col], init_value[col])
    return out
# these functions just help make random test data, one for each of JAX and NumPy
def random_walk_jax(start, scale, steps, seed, n_portfolio=None):
    key = random.PRNGKey(seed)
    shape = (steps - 1, n_portfolio) if n_portfolio else (steps - 1,)
    noise = random.normal(key, shape=shape) * scale
    noise = jnp.insert(noise, 0, start, axis=0)
    walk = jnp.cumsum(noise, axis=0)
    return walk

jax_data = random_walk_jax(100, 1, 100000, 1)


def random_walk_nb(start, scale, steps, seed, n_portfolio=None):
    rng = np.random.default_rng(seed)
    shape = (steps - 1, n_portfolio) if n_portfolio else (steps - 1,)
    noise = rng.normal(size=shape) * scale
    noise = np.insert(noise, 0, start, axis=0)
    walk = np.cumsum(noise, axis=0)
    return walk

nb_data = random_walk_nb(100, 1, 100000, 1)
# make sure the 1d jax thing works right
arr = jnp.array([100, 110, 100, 120])
numerator = jnp.diff(arr)
print("differences", numerator)
denominator = arr[:-1]
print("denominators", denominator)
returns = numerator / denominator
print("returns", returns)
%timeit returns_1d_jax(jax_data, 1.0)
nb_time = %timeit -o returns_1d_nb(nb_data, 1.0)
n_candles = [2048]
n_portfolios = [2 ** i for i in range(2, 12)]
n_loops = 5
timings = {"index": [], "jax": [], "nb": []}
testing = True
timing = True
for n_candle, n_portfolio in itertools.product(n_candles, n_portfolios):
    size = n_candle * n_portfolio
    print(f"\ncalculate ({n_candle}, {n_portfolio}) returns (size {size})\n")
    timings["index"].append(n_portfolio)

    def jax_fun(walk=None):
        if walk is None:
            seed = pyrandom.randint(0, 1000)
            walk = random_walk_jax(1, 0.01, n_candle, seed, n_portfolio=n_portfolio)
        returns = returns_jax(walk, 1.0).block_until_ready()
        return returns

    one = np.ones((n_portfolio,))

    def nb_fun(walk=None):
        if walk is None:
            seed = pyrandom.randint(0, 1000)
            walk = random_walk_nb(1, 0.01, n_candle, seed, n_portfolio=n_portfolio)
        returns = returns_nb(walk, one)
        return returns

    if timing:
        print("jax:")
        jax_timings = %timeit -o jax_fun()
        timings["jax"].append(jax_timings.average)
        print("nb:")
        nb_timings = %timeit -o nb_fun()
        timings["nb"].append(nb_timings.average)

    if testing:
        print("mean absolute error between jax and numba returns:")
        print("jnp.abs(nb_test - jax_test).mean() =")
        maes = []
        for i in range(n_loops):
            seed = pyrandom.randint(0, 1000)
            walk = random_walk_jax(1, 0.01, n_candle, seed, n_portfolio=n_portfolio)
            np_walk = np.array(walk)
            jax_test = jax_fun(walk)
            nb_test = nb_fun(np_walk)
            mae = jnp.abs(nb_test - jax_test).mean()
            maes.append(mae)
        print(maes)

if timing:
    df = pd.DataFrame.from_dict(timings)
    df.index = df["index"]
    df = df.drop(columns="index")
    df.plot(title="vectorbt returns accelerator benchmark", xlabel="n_portfolios", ylabel="time (lower is better)")
@bionicles looks great! JAX seems to be a solid alternative to Numba, especially for vectorizable functions and large arrays, even with the roughly 10 ms per-call overhead that your benchmarks show. I'm not sure whether this overhead comes from transferring the data to the device or from some JAX-related warmup; I need to do more tests myself. The next release will come with a global decorator for registering jitted functions, which will be a good starting point for adding JAX functionality in subsequent releases.
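One way to separate the two (a quick sketch, reusing the names from your benchmark): push the data to the device with jax.device_put and trigger one warmup call before timing, so the timed loop measures compute only.
import jax

# assumes `walk` and `returns_jax` from the benchmark above
walk_on_device = jax.device_put(walk)                          # pay host-to-device transfer once
returns_jax(walk_on_device, 1.0).block_until_ready()           # warmup call also pays compilation
%timeit returns_jax(walk_on_device, 1.0).block_until_ready()   # now mostly pure compute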
Hi, thanks for the cool project, I'm using it for a little ML experiment. Just curious, would you be open-minded about adding dividends and dividend reinvestment (DRIP) to the backtest results? (Or, if I'm dumb and they're working now, please let me know how to pass in the dividends data!)