joshspeagle / dynesty

Dynamic Nested Sampling package for computing Bayesian posteriors and evidences
https://dynesty.readthedocs.io/
MIT License

Tracking the performance of dynesty #432

Open segasai opened 1 year ago

segasai commented 1 year ago

This is an issue to track/discuss ways of monitoring dynesty's performance over time, to catch regressions early.

I've written a quick script to produce timings of the test suite across commits. Here's the output showing the relative timing of the different tests; each line is a different test, normalized so that its median duration is 1.

[plot: relative test duration vs. commit date, one line per test] It shows the impact of the recent dynamic sampling fix.
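(For reference, "normalized so that its median is 1" just means each test's duration series is divided by its own median; a minimal illustration with made-up numbers:)

import numpy as np

durations = np.array([2.1, 2.0, 2.2, 4.1])  # hypothetical timings of one test across commits
relative = durations / np.median(durations)  # the median maps to ~1; the 4.1 s run shows up as ~2x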

The script is below. The question is whether there is a way to track this automatically somehow; a sketch of one possible approach follows the script.

import os
import numpy as np
import math
import xml.etree.ElementTree  # plain 'import xml' does not load the ElementTree submodule
import glob
import re
import astropy.table as atpy

def fetch_commits():
    # Dump 'hash,date' for every commit back to a fixed starting commit;
    # the grep -B 10000 trick keeps everything before (i.e. newer than) that commit.
    cmd = '( echo hash,date ; git log --pretty=format:%H,%aI |grep -e  b6c9d7cf488939e3a4bdab8c405ce298c8843482 -B 10000 ) > commit.list'
    os.system(cmd)

def get_commits():
    # commit.list is written by fetch_commits()
    T = atpy.Table.read('commit.list', format='csv')
    return T['hash'], T['date']

def extractor(fout):
    # Collect per-test durations from the junit XML files into a CSV.
    commits, dates = get_commits()
    ids = np.arange(len(commits))
    # commit hash -> (sequential commit id, commit date)
    CD = {}
    for c, i, d in zip(commits, ids, dates):
        CD[c] = (i, d)

    with open(fout, 'w') as fp:
        print('commitid,commdate,commhash,testname1,testname2,duration',
              file=fp)
        for f in glob.glob('xmls/*xml'):
            D = open(f, 'r').read()

            # skip any run that had test failures
            M = re.search('failures="([0-9]+)"', D)
            if M is None:
                continue
            if int(M.group(1)) > 0:
                continue
            # the commit hash is embedded in the file name (xmls/dat_<hash>.xml)
            commit = f.split('.')[-2].split('_')[-1]
            if commit not in CD:
                continue
            D = xml.etree.ElementTree.parse(f)
            for el in D.findall('./testsuite/testcase'):
                el = dict(el.items())
                # aggregate entries have no classname; skip them
                if 'classname' not in el:
                    continue
                print('%d,%s,%s,%s,%s,%s' %
                      (CD[commit][0], CD[commit][1], commit, el['classname'],
                       el['name'], el['time']),
                      file=fp)

def plotter():
    # Plot each test's duration across commits, normalized by its median.
    extractor('output.dat')
    res = {}
    T = atpy.Table.read('output.dat', format='csv')
    # drop rows where the test name is missing
    T = T[~T['testname1'].mask]

    import matplotlib.pyplot as plt
    plt.clf()
    for tdate, tcid, a, b, cdt in zip(T['commdate'], T['commitid'],
                                      T['testname1'].value,
                                      T['testname2'].value, T['duration']):
        if (a, b) not in res:
            res[a, b] = []
        res[a, b].append((tcid, cdt, tdate))
    plt.ylim(.2, 4)
    for k, v in res.items():
        X, Y, Z = np.array(v).T
        Y = Y.astype(np.float32)  # durations in seconds
        Z = np.array(Z, dtype=np.datetime64)  # commit dates
        # only plot tests with enough samples and a non-trivial runtime
        if len(X) > 20 and np.median(Y) > 1:
            plt.semilogy(Z[np.argsort(Z)], (Y / np.median(Y))[np.argsort(Z)],
                         alpha=.2)
    plt.ylabel('Relative Time')

def generator():
    # Time the test suite for a coarse-to-fine sampling of commits, newest first.
    commits, dates = get_commits()
    ids = np.arange(len(commits))
    os.makedirs('xmls', exist_ok=True)
    maxpow = int(math.ceil(np.log(len(ids)) / np.log(2)))
    for pows in range(maxpow + 1, -1, -1):
        # the stride halves on each pass, so widely spaced commits get timed first
        curN = 1 << pows
        print(curN)
        for curid in ids[::curN]:
            commit = commits[curid]
            print(curid)
            ofname = f'xmls/dat_{commit}.xml'
            if os.path.exists(ofname):
                continue
            cmd = f'git checkout {commit}'
            os.system(cmd)
            # requires the pytest-xdist (-n, --dist) and pytest-timeout plugins
            cmd1 = f'''pytest  --junit-xml={ofname} -m 'not slow' --timeout 300 --dist loadgroup --durations 0 -n 36 tests/test*y'''
            os.system(cmd1)

if __name__ == '__main__':
    fetch_commits()  # make sure commit.list exists before get_commits() reads it
    generator()  # run plotter() afterwards to produce the figure
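As for tracking this automatically: one option (a minimal sketch, not a tested CI setup) would be to run the suite with --junit-xml on each new commit and compare each test's duration against a stored baseline, failing the job on large slowdowns. The check_regressions helper, the baseline CSV layout (classname,name,time columns), and the 1.5x / 1 s thresholds below are all made up for illustration:

import sys
import csv
import xml.etree.ElementTree


def read_durations(xml_path):
    # Map (classname, name) -> duration in seconds from a junit XML file.
    tree = xml.etree.ElementTree.parse(xml_path)
    out = {}
    for el in tree.getroot().iter('testcase'):
        if el.get('classname') is not None:
            out[el.get('classname'), el.get('name')] = float(el.get('time', 0))
    return out


def check_regressions(new_xml, baseline_csv, factor=1.5, min_seconds=1.0):
    # Return tests whose duration grew by more than `factor` vs the baseline;
    # tests under `min_seconds` in the baseline are ignored as too noisy.
    new = read_durations(new_xml)
    slow = []
    with open(baseline_csv) as fp:
        for row in csv.DictReader(fp):
            key = (row['classname'], row['name'])
            old_t = float(row['time'])
            if key in new and old_t > min_seconds and new[key] > factor * old_t:
                slow.append((key, old_t, new[key]))
    return slow


if __name__ == '__main__':
    regressions = check_regressions(sys.argv[1], sys.argv[2])
    for (cls, name), old_t, new_t in regressions:
        print('%s::%s slowed down %.1fs -> %.1fs' % (cls, name, old_t, new_t))
    sys.exit(1 if regressions else 0)

In CI this could run right after pytest --junit-xml=result.xml, with the baseline CSV regenerated from a trusted commit, and fail the build whenever anything is printed.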